Update.
[agtree2dot.git] / agtree2dot.py
1
2 #########################################################################
3 # This program is free software: you can redistribute it and/or modify  #
4 # it under the terms of the version 3 of the GNU General Public License #
5 # as published by the Free Software Foundation.                         #
6 #                                                                       #
7 # This program is distributed in the hope that it will be useful, but   #
8 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
10 # General Public License for more details.                              #
11 #                                                                       #
12 # You should have received a copy of the GNU General Public License     #
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
14 #                                                                       #
15 # Written by and Copyright (C) Francois Fleuret                         #
16 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
17 #########################################################################
18
19 import torch
20 import re
21 import sys
22
23 import torch.autograd
24
25 ######################################################################
26
27 def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
28
29     if isinstance(x, set):
30
31         for y in x:
32             save_dot_rec(y, node_labels, out, drawn_node_id)
33
34     else:
35
36         if not x in drawn_node_id:
37             drawn_node_id[x] = len(drawn_node_id) + 1
38
39             # Draw the node (Variable or Function) if not already
40             # drawn
41
42             if isinstance(x, torch.autograd.Variable):
43                 name = ((x in node_labels and node_labels[x]) or 'Variable')
44                 # Add the tensor size
45                 name = name + ' ['
46                 for d in range(0, x.data.dim()):
47                     if d > 0: name = name + ', '
48                     name = name + str(x.data.size(d))
49                 name = name + ']'
50
51                 out.write('  ' + str(drawn_node_id[x]) +
52                           ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
53
54                 if hasattr(x, 'creator') and x.creator:
55                     y = x.creator
56                     save_dot_rec(y, node_labels, out, drawn_node_id)
57                     # Edge to the creator
58                     out.write('  ' + str(drawn_node_id[y]) + ' -> ' +  str(drawn_node_id[x]) + '\n')
59
60             elif isinstance(x, torch.autograd.Function):
61                 name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
62                        re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
63
64                 prefix = ''
65                 suffix = ''
66
67                 if hasattr(x, 'num_inputs') and x.num_inputs > 1:
68                     prefix = '{ '
69                     for i in range(0, x.num_inputs):
70                         if i > 0: prefix = prefix + ' | '
71                         prefix = prefix + '<input' + str(i) + '> ' + str(i)
72                     prefix = prefix + ' } | '
73
74                 if hasattr(x, 'num_outputs') and x.num_outputs > 1:
75                     suffix = ' | { '
76                     for i in range(0, x.num_outputs):
77                         if i > 0: suffix = suffix + ' | '
78                         suffix = suffix + '<output' + str(i) + '> ' + str(i)
79                     suffix = suffix + ' }'
80
81                 out.write('  ' + str(drawn_node_id[x]) + \
82                           ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
83
84             else:
85
86                 print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
87                 exit(1)
88
89             if hasattr(x, 'num_inputs'):
90                 for i in range(0, x.num_inputs):
91                     y = x.previous_functions[i][0]
92                     save_dot_rec(y, node_labels, out, drawn_node_id)
93                     from_str = str(drawn_node_id[y])
94                     if hasattr(y, 'num_outputs') and y.num_outputs > 1:
95                         from_str = from_str + ':output' + str(x.previous_functions[i][1])
96                     to_str   = str(drawn_node_id[x])
97                     if x.num_inputs > 1:
98                         to_str = to_str + ':input' + str(i)
99                     out.write('  ' + from_str + ' -> ' +  to_str + '\n')
100
101 ######################################################################
102
103 def save_dot(x, node_labels = {}, out = sys.stdout):
104     out.write('digraph {\n')
105     save_dot_rec(x, node_labels, out, {})
106     out.write('}\n')
107
108 ######################################################################