+
+#########################################################################
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation. #
+# #
+# This program is distributed in the hope that it will be useful, but #
+# WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
+# General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <http://www.gnu.org/licenses/>. #
+# #
+# Written by and Copyright (C) Francois Fleuret #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports #
+#########################################################################
+
+import torch
+import re
+import sys
+
+import torch.autograd
+
+######################################################################
+
+def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
+
+ if isinstance(x, set):
+
+ for y in x:
+ save_dot_rec(y, node_labels, out, drawn_node_id)
+
+ else:
+
+ if not x in drawn_node_id:
+ drawn_node_id[x] = len(drawn_node_id) + 1
+
+ # Draw the node (Variable or Function) if not already
+ # drawn
+
+ if isinstance(x, torch.autograd.Variable):
+ name = ((x in node_labels and node_labels[x]) or 'Variable')
+ # Add the tensor size
+ name = name + ' ['
+ for d in range(0, x.data.dim()):
+ if d > 0: name = name + ', '
+ name = name + str(x.data.size(d))
+ name = name + ']'
+
+ out.write(' ' + str(drawn_node_id[x]) +
+ ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
+
+ if hasattr(x, 'creator') and x.creator:
+ y = x.creator
+ save_dot_rec(y, node_labels, out, drawn_node_id)
+ # Edge to the creator
+ out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n')
+
+ elif isinstance(x, torch.autograd.Function):
+ name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
+ re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
+
+ prefix = ''
+ suffix = ''
+
+ if hasattr(x, 'num_inputs') and x.num_inputs > 1:
+ prefix = '{ '
+ for i in range(0, x.num_inputs):
+ if i > 0: prefix = prefix + ' | '
+ prefix = prefix + '<input' + str(i) + '> ' + str(i)
+ prefix = prefix + ' } | '
+
+ if hasattr(x, 'num_outputs') and x.num_outputs > 1:
+ suffix = ' | { '
+ for i in range(0, x.num_outputs):
+ if i > 0: suffix = suffix + ' | '
+ suffix = suffix + '<output' + str(i) + '> ' + str(i)
+ suffix = suffix + ' }'
+
+ out.write(' ' + str(drawn_node_id[x]) + \
+ ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
+
+ else:
+
+ print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
+ exit(1)
+
+ if hasattr(x, 'num_inputs'):
+ for i in range(0, x.num_inputs):
+ y = x.previous_functions[i][0]
+ save_dot_rec(y, node_labels, out, drawn_node_id)
+ from_str = str(drawn_node_id[y])
+ if hasattr(y, 'num_outputs') and y.num_outputs > 1:
+ from_str = from_str + ':output' + str(x.previous_functions[i][1])
+ to_str = str(drawn_node_id[x])
+ if x.num_inputs > 1:
+ to_str = to_str + ':input' + str(i)
+ out.write(' ' + from_str + ' -> ' + to_str + '\n')
+
+######################################################################
+
+def save_dot(x, node_labels = {}, out = sys.stdout):
+ out.write('digraph {\n')
+ save_dot_rec(x, node_labels, out, {})
+ out.write('}\n')
+
+######################################################################