1 #########################################################################
2 # This program is free software: you can redistribute it and/or modify #
3 # it under the terms of the version 3 of the GNU General Public License #
4 # as published by the Free Software Foundation. #
6 # This program is distributed in the hope that it will be useful, but #
7 # WITHOUT ANY WARRANTY; without even the implied warranty of #
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
9 # General Public License for more details. #
11 # You should have received a copy of the GNU General Public License #
12 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
14 # Written by and Copyright (C) Francois Fleuret #
15 # Contact <francois.fleuret@idiap.ch> for comments & bug reports #
16 #########################################################################
21 ######################################################################
24 def __init__(self, from_node, from_nb, to_node, to_nb):
25 self.from_node = from_node
26 self.from_nb = from_nb
27 self.to_node = to_node
31 def __init__(self, id, label):
37 def slot(node_list, n, k, for_input):
39 if node_list[n].max_out > 0:
40 return str(node_list[n].id) + ':input' + str(k)
42 return str(node_list[n].id)
44 if node_list[n].max_in > 0:
45 return str(node_list[n].id) + ':output' + str(k)
47 return str(node_list[n].id)
49 def slot_string(k, for_input):
58 if not for_input: result = ' |' + result
59 result += ' { <' + label + '0> 0'
60 for j in range(1, k + 1):
61 result += " | " + '<' + label + str(j) + '> ' + str(j)
63 if for_input: result = result + '| '
67 ######################################################################
def add_link(node_list, link_list, u, nu, v, nv):
    """Record an edge from slot ``nu`` of ``u`` to slot ``nv`` of ``v``.

    When either endpoint is None the call is a no-op. Otherwise a new
    ``Link`` is appended to ``link_list``, ``node_list[u].max_in`` is raised
    to at least ``nu`` and ``node_list[v].max_out`` is raised to at least
    ``nv`` (these counters size the per-node slot lists in the dot output).

    Args:
        node_list: dict mapping graph objects to their Node records.
        link_list: list that collects the Link records.
        u, nu: source object and its slot index.
        v, nv: destination object and its slot index.
    """
    # Missing endpoints (e.g. a None entry in next_functions) produce no edge.
    if u is None or v is None:
        return
    link_list.append(Link(u, nu, v, nv))
    source = node_list[u]
    source.max_in = max(source.max_in, nu)
    target = node_list[v]
    target.max_out = max(target.max_out, nv)
76 ######################################################################
78 def fill_graph_lists(u, node_labels, node_list, link_list):
80 if u is not None and not u in node_list:
81 node = Node(len(node_list) + 1,
82 (u in node_labels and node_labels[u]) or \
83 re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
86 if hasattr(u, 'grad_fn'):
87 fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
88 add_link(node_list, link_list, u, 0, u.grad_fn, 0)
90 if hasattr(u, 'variable'):
91 fill_graph_lists(u.variable, node_labels, node_list, link_list)
92 add_link(node_list, link_list, u, 0, u.variable, 0)
94 if hasattr(u, 'next_functions'):
95 for i, (v, j) in enumerate(u.next_functions):
96 fill_graph_lists(v, node_labels, node_list, link_list)
97 add_link(node_list, link_list, u, i, v, j)
99 ######################################################################
101 def print_dot(node_list, link_list, out):
102 out.write('digraph{\n')
107 if isinstance(n, torch.autograd.Variable):
110 str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
111 node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
117 str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
118 slot_string(node.max_out, for_input = True) + \
120 slot_string(node.max_in, for_input = False) + \
126 slot(node_list, n.from_node, n.from_nb, for_input = False) + \
128 slot(node_list, n.to_node, n.to_nb, for_input = True) + \
133 ######################################################################
def save_dot(x, node_labels=None, out=None):
    """Write the graph reachable from ``x`` to ``out`` in Graphviz dot syntax.

    The nodes and links are collected by ``fill_graph_lists``, which follows
    the ``grad_fn``, ``variable`` and ``next_functions`` attributes starting
    from ``x``. They are then rendered by ``print_dot``.

    Args:
        x: root of the graph to dump; presumably a torch Variable/tensor or
            an autograd function object -- TODO confirm accepted types.
        node_labels: optional mapping from graph objects to display labels.
            Objects without a (truthy) entry get a label derived from their
            type name. ``None`` means "no custom labels".
        out: writable text stream that receives the dot source. ``None``
            (the default) means ``sys.stdout`` as it is at call time.
    """
    # Fresh dict per call instead of a shared mutable default argument.
    if node_labels is None:
        node_labels = {}
    # Resolve stdout per call: a default of ``sys.stdout`` in the signature
    # is captured once at import time and ignores later redirections
    # (contextlib.redirect_stdout, test output capture, ...).
    if out is None:
        out = sys.stdout
    node_list, link_list = {}, []
    fill_graph_lists(x, node_labels, node_list, link_list)
    print_dot(node_list, link_list, out)
140 ######################################################################