Initial commit.
author Francois Fleuret <francois@fleuret.org>
Wed, 8 Mar 2017 10:39:09 +0000 (11:39 +0100)
committer Francois Fleuret <francois@fleuret.org>
Wed, 8 Mar 2017 10:39:09 +0000 (11:39 +0100)
README.md [new file with mode: 0644]
agtree2dot.py [new file with mode: 0755]
mlp.py [new file with mode: 0755]

diff --git a/README.md b/README.md
new file mode 100644 (file)
index 0000000..4c219b7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,59 @@
+# Introduction #
+
+This package provides a function that generates a Graphviz dot file from
+a PyTorch autograd graph.
+
+# Usage #
+
+## Functions ##
+
+### agtree2dot.save_dot(variable, variable_labels, result_file) ###
+
+Saves into `result_file` a dot file corresponding to the autograd graph of `variable`, which can be either a single `Variable` or a set of `Variable`s. The dictionary `variable_labels` associates labels (strings) with some of the variables; these labels are used to annotate the corresponding nodes in the resulting graph.
+
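+A hypothetical sketch of the set-of-`Variable`s form, reusing the `loss`
+and `output` variables defined in the example below, would be:
+
+```python
+agtree2dot.save_dot({ loss, output },
+                    { loss: 'loss', output: 'output' },
+                    open('./graphs.dot', 'w'))
+```
+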
+## Example ##
+
+A typical use would be:
+
+```python
+import torch
+
+from torch import nn
+from torch.nn import functional as fn
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn import Module
+
+import agtree2dot
+
+class MLP(Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super(MLP, self).__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = fn.tanh(x)
+        x = self.fc2(x)
+        return x
+
+mlp = MLP(10, 20, 1)
+input = Variable(Tensor(100, 10).normal_())
+target = Variable(Tensor(100).normal_())
+output = mlp(input)
+criterion = nn.MSELoss()
+loss = criterion(output, target)
+
+agtree2dot.save_dot(loss,
+                    { input: 'input', loss: 'loss' },
+                    open('./mlp.dot', 'w'))
+```
+
+This generates a file `mlp.dot`, which can then be converted to a PDF with
+
+```
+dot mlp.dot -Lg -T pdf -o mlp.pdf
+```
+
+to produce [mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf).
diff --git a/agtree2dot.py b/agtree2dot.py
new file mode 100755 (executable)
index 0000000..f215f94
--- /dev/null
+++ b/agtree2dot.py
@@ -0,0 +1,108 @@
+
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
+#########################################################################
+
+import torch
+import re
+import sys
+
+import torch.autograd
+
+######################################################################
+
+def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = None):
+
+    # Use a fresh dictionary for each top-level call (a mutable default
+    # argument would be shared across calls)
+    if drawn_node_id is None:
+        drawn_node_id = {}
+
+    if isinstance(x, set):
+
+        for y in x:
+            save_dot_rec(y, node_labels, out, drawn_node_id)
+
+    else:
+
+        if x not in drawn_node_id:
+            drawn_node_id[x] = len(drawn_node_id) + 1
+
+            # Draw the node (Variable or Function) if not already
+            # drawn
+
+            if isinstance(x, torch.autograd.Variable):
+                name = ((x in node_labels and node_labels[x]) or 'Variable')
+                # Add the tensor size
+                name = name + ' ['
+                for d in range(0, x.data.dim()):
+                    if d > 0: name = name + ', '
+                    name = name + str(x.data.size(d))
+                name = name + ']'
+
+                out.write('  ' + str(drawn_node_id[x]) +
+                          ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
+
+                if hasattr(x, 'creator') and x.creator:
+                    y = x.creator
+                    save_dot_rec(y, node_labels, out, drawn_node_id)
+                    # Edge to the creator
+                    out.write('  ' + str(drawn_node_id[y]) + ' -> ' +  str(drawn_node_id[x]) + '\n')
+
+            elif isinstance(x, torch.autograd.Function):
+                name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
+                       re.search(r"<.*\.([a-zA-Z0-9_]*)'>", str(type(x))).group(1)
+
+                prefix = ''
+                suffix = ''
+
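+                # When the Function has several inputs or outputs, list them
+                # as named record ports (<input0>, <output1>, ...) so that
+                # edges can attach to the proper slot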
+                if hasattr(x, 'num_inputs') and x.num_inputs > 1:
+                    prefix = '{ '
+                    for i in range(0, x.num_inputs):
+                        if i > 0: prefix = prefix + ' | '
+                        prefix = prefix + '<input' + str(i) + '> ' + str(i)
+                    prefix = prefix + ' } | '
+
+                if hasattr(x, 'num_outputs') and x.num_outputs > 1:
+                    suffix = ' | { '
+                    for i in range(0, x.num_outputs):
+                        if i > 0: suffix = suffix + ' | '
+                        suffix = suffix + '<output' + str(i) + '> ' + str(i)
+                    suffix = suffix + ' }'
+
+                out.write('  ' + str(drawn_node_id[x]) + \
+                          ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
+
+            else:
+
+                print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
+                sys.exit(1)
+
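+            # Recurse over the inputs of a Function and draw the incoming
+            # edges, going through the record ports when they exist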
+            if hasattr(x, 'num_inputs'):
+                for i in range(0, x.num_inputs):
+                    y = x.previous_functions[i][0]
+                    save_dot_rec(y, node_labels, out, drawn_node_id)
+                    from_str = str(drawn_node_id[y])
+                    if hasattr(y, 'num_outputs') and y.num_outputs > 1:
+                        from_str = from_str + ':output' + str(x.previous_functions[i][1])
+                    to_str   = str(drawn_node_id[x])
+                    if x.num_inputs > 1:
+                        to_str = to_str + ':input' + str(i)
+                    out.write('  ' + from_str + ' -> ' +  to_str + '\n')
+
+######################################################################
+
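+# Writes to out, in the dot syntax, the autograd graph of x, which can be
+# either a single Variable or a set of Variables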
+def save_dot(x, node_labels = {}, out = sys.stdout):
+    out.write('digraph {\n')
+    save_dot_rec(x, node_labels, out, {})
+    out.write('}\n')
+
+######################################################################
diff --git a/mlp.py b/mlp.py
new file mode 100755 (executable)
index 0000000..8497848
--- /dev/null
+++ b/mlp.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
+#########################################################################
+
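+# Builds a small two-layer MLP, computes an MSE loss on random data, and
+# saves the corresponding autograd graph to ./mlp.dot
+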
+from torch import nn
+from torch.nn import functional as fn
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn import Module
+
+import agtree2dot
+
+class MLP(Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super(MLP, self).__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = fn.tanh(x)
+        x = self.fc2(x)
+        return x
+
+mlp = MLP(10, 20, 1)
+input = Variable(Tensor(100, 10).normal_())
+target = Variable(Tensor(100).normal_())
+output = mlp(input)
+criterion = nn.MSELoss()
+loss = criterion(output, target)
+
+agtree2dot.save_dot(loss,
+                    { input: 'input', target: 'target', loss: 'loss' },
+                    open('./mlp.dot', 'w'))
+
+print('Generated mlp.dot. You can convert it to pdf with')
+print('> dot mlp.dot -Lg -T pdf -o mlp.pdf')