######################################################################
-def build_ag_graph_lists(u, node_labels, node_list, link_list):
+def fill_graph_lists(u, node_labels, node_list, link_list):
if u is not None and not u in node_list:
node = Node(len(node_list) + 1,
node_list[u] = node
if isinstance(u, torch.autograd.Variable):
- build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
+ fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
- else:
- if hasattr(u, 'next_functions'):
- i = 0
- for v, j in u.next_functions:
- build_ag_graph_lists(v, node_labels, node_list, link_list)
- add_link(node_list, link_list, u, i, v, j)
- i += 1
+
+ if hasattr(u, 'variable'):
+ fill_graph_lists(u.variable, node_labels, node_list, link_list)
+ add_link(node_list, link_list, u, 0, u.variable, 0)
+
+ if hasattr(u, 'next_functions'):
+ i = 0
+ for v, j in u.next_functions:
+ fill_graph_lists(v, node_labels, node_list, link_list)
+ add_link(node_list, link_list, u, i, v, j)
+ i += 1
######################################################################
for n in node_list:
node = node_list[n]
- out.write(
- ' ' + \
- str(node.id) + ' [shape=record,label="{ ' + \
- slot_string(node.max_out, for_input = True) + \
- node.label + \
- slot_string(node.max_in, for_input = False) + \
- ' }"]\n'
- )
+ if isinstance(n, torch.autograd.Variable):
+ out.write(
+ ' ' + \
+ str(node.id) + ' [shape=note,label="' + \
+ node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+ '"]\n'
+ )
+ else:
+ out.write(
+ ' ' + \
+ str(node.id) + ' [shape=record,label="{ ' + \
+ slot_string(node.max_out, for_input = True) + \
+ node.label + \
+ slot_string(node.max_in, for_input = False) + \
+ ' }"]\n'
+ )
for n in link_list:
out.write(' ' + \
def save_dot(x, node_labels = {}, out = sys.stdout):
node_list, link_list = {}, []
- build_ag_graph_lists(x, node_labels, node_list, link_list)
+ fill_graph_lists(x, node_labels, node_list, link_list)
print_dot(node_list, link_list, out)
######################################################################