Style variable nodes differently, shows the tensor size, invoke the dot command in...
[agtree2dot] / agtree2dot.py
1 #########################################################################
2 # This program is free software: you can redistribute it and/or modify  #
3 # it under the terms of the version 3 of the GNU General Public License #
4 # as published by the Free Software Foundation.                         #
5 #                                                                       #
6 # This program is distributed in the hope that it will be useful, but   #
7 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
9 # General Public License for more details.                              #
10 #                                                                       #
11 # You should have received a copy of the GNU General Public License     #
12 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
13 #                                                                       #
14 # Written by and Copyright (C) Francois Fleuret                         #
15 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
16 #########################################################################
17
18 import torch
19 import sys, re
20
21 ######################################################################
22
23 class Link:
24     def __init__(self, from_node, from_nb, to_node, to_nb):
25         self.from_node = from_node
26         self.from_nb = from_nb
27         self.to_node = to_node
28         self.to_nb = to_nb
29
30 class Node:
31     def __init__(self, id, label):
32         self.id = id
33         self.label = label
34         self.max_in = -1
35         self.max_out = -1
36
37 def slot(node_list, n, k, for_input):
38     if for_input:
39         if node_list[n].max_out > 0:
40             return str(node_list[n].id) + ':input' + str(k)
41         else:
42             return str(node_list[n].id)
43     else:
44         if node_list[n].max_in > 0:
45             return str(node_list[n].id) + ':output' + str(k)
46         else:
47             return str(node_list[n].id)
48
49 def slot_string(k, for_input):
50     result = ''
51
52     if for_input:
53         label = 'input'
54     else:
55         label = 'output'
56
57     if k > 0:
58         if not for_input: result = ' |' + result
59         result +=  ' { <' + label + '0> 0'
60         for j in range(1, k + 1):
61             result += " | " + '<' + label + str(j) + '> ' + str(j)
62         result += " } "
63         if for_input: result = result + '| '
64
65     return result
66
67 ######################################################################
68
69 def add_link(node_list, link_list, u, nu, v, nv):
70     if u is not None and v is not None:
71         link = Link(u, nu, v, nv)
72         link_list.append(link)
73         node_list[u].max_in  = max(node_list[u].max_in,  nu)
74         node_list[v].max_out = max(node_list[v].max_out, nv)
75
76 ######################################################################
77
78 def fill_graph_lists(u, node_labels, node_list, link_list):
79
80     if u is not None and not u in node_list:
81         node = Node(len(node_list) + 1,
82                     (u in node_labels and node_labels[u]) or \
83                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
84         node_list[u] = node
85
86         if isinstance(u, torch.autograd.Variable):
87             fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
88             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
89
90         if hasattr(u, 'variable'):
91             fill_graph_lists(u.variable, node_labels, node_list, link_list)
92             add_link(node_list, link_list, u, 0, u.variable, 0)
93
94         if hasattr(u, 'next_functions'):
95             i = 0
96             for v, j in u.next_functions:
97                 fill_graph_lists(v, node_labels, node_list, link_list)
98                 add_link(node_list, link_list, u, i, v, j)
99                 i += 1
100
101 ######################################################################
102
103 def print_dot(node_list, link_list, out):
104     out.write('digraph{\n')
105
106     for n in node_list:
107         node = node_list[n]
108
109         if isinstance(n, torch.autograd.Variable):
110             out.write(
111                 '  ' + \
112                 str(node.id) + ' [shape=note,label="' + \
113                 node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
114                 '"]\n'
115             )
116         else:
117             out.write(
118                 '  ' + \
119                 str(node.id) + ' [shape=record,label="{ ' + \
120                 slot_string(node.max_out, for_input = True) + \
121                 node.label + \
122                 slot_string(node.max_in, for_input = False) + \
123                 ' }"]\n'
124             )
125
126     for n in link_list:
127         out.write('  ' + \
128                   slot(node_list, n.from_node, n.from_nb, for_input = False) + \
129                   ' -> ' + \
130                   slot(node_list, n.to_node, n.to_nb, for_input = True) + \
131                   '\n')
132
133     out.write('}\n')
134
135 ######################################################################
136
137 def save_dot(x, node_labels = {}, out = sys.stdout):
138     node_list, link_list = {}, []
139     fill_graph_lists(x, node_labels, node_list, link_list)
140     print_dot(node_list, link_list, out)
141
142 ######################################################################