Style variable nodes differently, shows the tensor size, invoke the dot command in...
[agtree2dot.git] / agtree2dot.py
index 8931e36..8cc9e8c 100755 (executable)
@@ -75,7 +75,7 @@ def add_link(node_list, link_list, u, nu, v, nv):
 
 ######################################################################
 
-def build_ag_graph_lists(u, node_labels, node_list, link_list):
+def fill_graph_lists(u, node_labels, node_list, link_list):
 
     if u is not None and not u in node_list:
         node = Node(len(node_list) + 1,
@@ -84,15 +84,19 @@ def build_ag_graph_lists(u, node_labels, node_list, link_list):
         node_list[u] = node
 
         if isinstance(u, torch.autograd.Variable):
-            build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
+            fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
-        else:
-            if hasattr(u, 'next_functions'):
-                i = 0
-                for v, j in u.next_functions:
-                    build_ag_graph_lists(v, node_labels, node_list, link_list)
-                    add_link(node_list, link_list, u, i, v, j)
-                    i += 1
+
+        if hasattr(u, 'variable'):
+            fill_graph_lists(u.variable, node_labels, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.variable, 0)
+
+        if hasattr(u, 'next_functions'):
+            i = 0
+            for v, j in u.next_functions:
+                fill_graph_lists(v, node_labels, node_list, link_list)
+                add_link(node_list, link_list, u, i, v, j)
+                i += 1
 
 ######################################################################
 
@@ -102,14 +106,22 @@ def print_dot(node_list, link_list, out):
     for n in node_list:
         node = node_list[n]
 
-        out.write(
-            '  ' + \
-            str(node.id) + ' [shape=record,label="{ ' + \
-            slot_string(node.max_out, for_input = True) + \
-            node.label + \
-            slot_string(node.max_in, for_input = False) + \
-            ' }"]\n'
-        )
+        if isinstance(n, torch.autograd.Variable):
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=note,label="' + \
+                node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+                '"]\n'
+            )
+        else:
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=record,label="{ ' + \
+                slot_string(node.max_out, for_input = True) + \
+                node.label + \
+                slot_string(node.max_in, for_input = False) + \
+                ' }"]\n'
+            )
 
     for n in link_list:
         out.write('  ' + \
@@ -124,7 +136,7 @@ def print_dot(node_list, link_list, out):
 
 def save_dot(x, node_labels = {}, out = sys.stdout):
     node_list, link_list = {}, []
-    build_ag_graph_lists(x, node_labels, node_list, link_list)
+    fill_graph_lists(x, node_labels, node_list, link_list)
     print_dot(node_list, link_list, out)
 
 ######################################################################