Complete re-write for pytorch 0.2.0
authorFrancois Fleuret <francois@fleuret.org>
Sun, 20 Aug 2017 20:17:44 +0000 (22:17 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 20 Aug 2017 20:17:44 +0000 (22:17 +0200)
agtree2dot.py
mlp.pdf

index f215f94..2d89af5 100755 (executable)
-
-#########################################################################
-# This program is free software: you can redistribute it and/or modify  #
-# it under the terms of the version 3 of the GNU General Public License #
-# as published by the Free Software Foundation.                         #
-#                                                                       #
-# This program is distributed in the hope that it will be useful, but   #
-# WITHOUT ANY WARRANTY; without even the implied warranty of            #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
-# General Public License for more details.                              #
-#                                                                       #
-# You should have received a copy of the GNU General Public License     #
-# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
-#                                                                       #
-# Written by and Copyright (C) Francois Fleuret                         #
-# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
-#########################################################################
+#!/usr/bin/env python-for-pytorch
 
 import torch
-import re
-import sys
+import math, sys, re
 
-import torch.autograd
+from torch import nn
+from torch.nn import functional as fn
 
-######################################################################
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+from torch.nn import Module
 
-def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
+######################################################################
 
-    if isinstance(x, set):
+class Link:
+    def __init__(self, from_node, from_nb, to_node, to_nb):
+        self.from_node = from_node
+        self.from_nb = from_nb
+        self.to_node = to_node
+        self.to_nb = to_nb
+
+class Node:
+    def __init__(self, id, label):
+        self.id = id
+        self.label = label
+        self.max_in = -1
+        self.max_out = -1
+
+def slot(node_list, n, k, for_input):
+    if for_input:
+        if node_list[n].max_out > 0:
+            return str(node_list[n].id) + ':input' + str(k)
+        else:
+            return str(node_list[n].id)
+    else:
+        if node_list[n].max_in > 0:
+            return str(node_list[n].id) + ':output' + str(k)
+        else:
+            return str(node_list[n].id)
 
-        for y in x:
-            save_dot_rec(y, node_labels, out, drawn_node_id)
+def slot_string(k, for_input):
+    result = ''
 
+    if for_input:
+        label = 'input'
     else:
+        label = 'output'
+
+    if k > 0:
+        if not for_input: result = ' |' + result
+        result +=  ' { <' + label + '0> 0'
+        for j in range(1, k+1):
+            result += " | " + '<' + label + str(j) + '> ' + str(j)
+        result += " } "
+        if for_input: result = result + '| '
 
-        if not x in drawn_node_id:
-            drawn_node_id[x] = len(drawn_node_id) + 1
-
-            # Draw the node (Variable or Function) if not already
-            # drawn
-
-            if isinstance(x, torch.autograd.Variable):
-                name = ((x in node_labels and node_labels[x]) or 'Variable')
-                # Add the tensor size
-                name = name + ' ['
-                for d in range(0, x.data.dim()):
-                    if d > 0: name = name + ', '
-                    name = name + str(x.data.size(d))
-                name = name + ']'
-
-                out.write('  ' + str(drawn_node_id[x]) +
-                          ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
-
-                if hasattr(x, 'creator') and x.creator:
-                    y = x.creator
-                    save_dot_rec(y, node_labels, out, drawn_node_id)
-                    # Edge to the creator
-                    out.write('  ' + str(drawn_node_id[y]) + ' -> ' +  str(drawn_node_id[x]) + '\n')
-
-            elif isinstance(x, torch.autograd.Function):
-                name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
-                       re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
-
-                prefix = ''
-                suffix = ''
-
-                if hasattr(x, 'num_inputs') and x.num_inputs > 1:
-                    prefix = '{ '
-                    for i in range(0, x.num_inputs):
-                        if i > 0: prefix = prefix + ' | '
-                        prefix = prefix + '<input' + str(i) + '> ' + str(i)
-                    prefix = prefix + ' } | '
-
-                if hasattr(x, 'num_outputs') and x.num_outputs > 1:
-                    suffix = ' | { '
-                    for i in range(0, x.num_outputs):
-                        if i > 0: suffix = suffix + ' | '
-                        suffix = suffix + '<output' + str(i) + '> ' + str(i)
-                    suffix = suffix + ' }'
-
-                out.write('  ' + str(drawn_node_id[x]) + \
-                          ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
-
-            else:
-
-                print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
-                exit(1)
-
-            if hasattr(x, 'num_inputs'):
-                for i in range(0, x.num_inputs):
-                    y = x.previous_functions[i][0]
-                    save_dot_rec(y, node_labels, out, drawn_node_id)
-                    from_str = str(drawn_node_id[y])
-                    if hasattr(y, 'num_outputs') and y.num_outputs > 1:
-                        from_str = from_str + ':output' + str(x.previous_functions[i][1])
-                    to_str   = str(drawn_node_id[x])
-                    if x.num_inputs > 1:
-                        to_str = to_str + ':input' + str(i)
-                    out.write('  ' + from_str + ' -> ' +  to_str + '\n')
+    return result
 
 ######################################################################
 
-def save_dot(x, node_labels = {}, out = sys.stdout):
-    out.write('digraph {\n')
-    save_dot_rec(x, node_labels, out, {})
+def add_link(node_list, link_list, u, nu, v, nv):
+    link = Link(u, nu, v, nv)
+    link_list.append(link)
+    node_list[u].max_in  = max(node_list[u].max_in,  nu)
+    node_list[v].max_out = max(node_list[u].max_out, nv)
+
+######################################################################
+
+def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
+
+    if not u in node_list:
+        node = Node(len(node_list) + 1,
+                    (u in node_labels and node_labels[u]) or \
+                    re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
+        node_list[u] = node
+
+        if isinstance(u, torch.autograd.Variable):
+            build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.grad_fn, 0)
+        else:
+            if hasattr(u, 'next_functions'):
+                i = 0
+                for v, j in u.next_functions:
+                    build_ag_graph_lists(v, node_labels, out, node_list, link_list)
+                    add_link(node_list, link_list, u, i, v, j)
+                    i += 1
+
+######################################################################
+
+def print_dot(node_list, link_list, out):
+    out.write('digraph{\n')
+
+    for n in node_list:
+        node = node_list[n]
+
+        out.write(
+            '  ' + \
+            str(node.id) + ' [shape=record,label="{ ' + \
+            slot_string(node.max_out, for_input = True) + \
+            node.label + \
+            slot_string(node.max_in, for_input = False) + \
+            ' }"]\n'
+        )
+
+    for n in link_list:
+        out.write('  ' + \
+                  slot(node_list, n.from_node, n.from_nb, for_input = False) + \
+                  ' -> ' + \
+                  slot(node_list, n.to_node, n.to_nb, for_input = True) + \
+                  '\n')
+
     out.write('}\n')
 
 ######################################################################
+
+def save_dot(x, node_labels = {}, out = sys.stdout):
+    node_list = {}
+    link_list = []
+    build_ag_graph_lists(x, node_labels, out, node_list, link_list)
+    print_dot(node_list, link_list, out)
+
+######################################################################
+
+# x = Variable(torch.rand(5))
+# y = torch.topk(x, 3)
+# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
+
+# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))
diff --git a/mlp.pdf b/mlp.pdf
index 0f735a5..f233b20 100644 (file)
Binary files a/mlp.pdf and b/mlp.pdf differ