Fixes and cleanup.
authorFrancois Fleuret <francois@fleuret.org>
Mon, 21 Aug 2017 05:42:22 +0000 (07:42 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 21 Aug 2017 05:42:22 +0000 (07:42 +0200)
agtree2dot.py
mlp.pdf

index 7643986..8931e36 100755 (executable)
@@ -57,7 +57,7 @@ def slot_string(k, for_input):
     if k > 0:
         if not for_input: result = ' |' + result
         result +=  ' { <' + label + '0> 0'
-        for j in range(1, k+1):
+        for j in range(1, k + 1):
             result += " | " + '<' + label + str(j) + '> ' + str(j)
         result += " } "
         if for_input: result = result + '| '
@@ -67,29 +67,30 @@ def slot_string(k, for_input):
 ######################################################################
 
 def add_link(node_list, link_list, u, nu, v, nv):
-    link = Link(u, nu, v, nv)
-    link_list.append(link)
-    node_list[u].max_in  = max(node_list[u].max_in,  nu)
-    node_list[v].max_out = max(node_list[u].max_out, nv)
+    if u is not None and v is not None:
+        link = Link(u, nu, v, nv)
+        link_list.append(link)
+        node_list[u].max_in  = max(node_list[u].max_in,  nu)
+        node_list[v].max_out = max(node_list[v].max_out, nv)
 
 ######################################################################
 
-def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
+def build_ag_graph_lists(u, node_labels, node_list, link_list):
 
-    if not u in node_list:
+    if u is not None and not u in node_list:
         node = Node(len(node_list) + 1,
                     (u in node_labels and node_labels[u]) or \
                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
         node_list[u] = node
 
         if isinstance(u, torch.autograd.Variable):
-            build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
+            build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
         else:
             if hasattr(u, 'next_functions'):
                 i = 0
                 for v, j in u.next_functions:
-                    build_ag_graph_lists(v, node_labels, out, node_list, link_list)
+                    build_ag_graph_lists(v, node_labels, node_list, link_list)
                     add_link(node_list, link_list, u, i, v, j)
                     i += 1
 
@@ -122,15 +123,8 @@ def print_dot(node_list, link_list, out):
 ######################################################################
 
 def save_dot(x, node_labels = {}, out = sys.stdout):
-    node_list = {}
-    link_list = []
-    build_ag_graph_lists(x, node_labels, out, node_list, link_list)
+    node_list, link_list = {}, []
+    build_ag_graph_lists(x, node_labels, node_list, link_list)
     print_dot(node_list, link_list, out)
 
 ######################################################################
-
-# x = Variable(torch.rand(5))
-# y = torch.topk(x, 3)
-# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
-
-# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))
diff --git a/mlp.pdf b/mlp.pdf
index 4abdf28..0f41f81 100644 (file)
Binary files a/mlp.pdf and b/mlp.pdf differ