- if not x in drawn_node_id:
- drawn_node_id[x] = len(drawn_node_id) + 1
-
- # Draw the node (Variable or Function) if not already
- # drawn
-
- if isinstance(x, torch.autograd.Variable):
- name = ((x in node_labels and node_labels[x]) or 'Variable')
- # Add the tensor size
- name = name + ' ['
- for d in range(0, x.data.dim()):
- if d > 0: name = name + ', '
- name = name + str(x.data.size(d))
- name = name + ']'
-
- out.write(' ' + str(drawn_node_id[x]) +
- ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
-
- if hasattr(x, 'creator') and x.creator:
- y = x.creator
- save_dot_rec(y, node_labels, out, drawn_node_id)
- # Edge to the creator
- out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n')
-
- elif isinstance(x, torch.autograd.Function):
- name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
- re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
-
- prefix = ''
- suffix = ''
-
- if hasattr(x, 'num_inputs') and x.num_inputs > 1:
- prefix = '{ '
- for i in range(0, x.num_inputs):
- if i > 0: prefix = prefix + ' | '
- prefix = prefix + '<input' + str(i) + '> ' + str(i)
- prefix = prefix + ' } | '
-
- if hasattr(x, 'num_outputs') and x.num_outputs > 1:
- suffix = ' | { '
- for i in range(0, x.num_outputs):
- if i > 0: suffix = suffix + ' | '
- suffix = suffix + '<output' + str(i) + '> ' + str(i)
- suffix = suffix + ' }'
-
- out.write(' ' + str(drawn_node_id[x]) + \
- ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
-
- else:
-
- print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
- exit(1)
-
- if hasattr(x, 'num_inputs'):
- for i in range(0, x.num_inputs):
- y = x.previous_functions[i][0]
- save_dot_rec(y, node_labels, out, drawn_node_id)
- from_str = str(drawn_node_id[y])
- if hasattr(y, 'num_outputs') and y.num_outputs > 1:
- from_str = from_str + ':output' + str(x.previous_functions[i][1])
- to_str = str(drawn_node_id[x])
- if x.num_inputs > 1:
- to_str = to_str + ':input' + str(i)
- out.write(' ' + from_str + ' -> ' + to_str + '\n')