1 #!/usr/bin/env python-for-pytorch
7 from torch.nn import functional as fn
9 from torch import Tensor
10 from torch.autograd import Variable
11 from torch.nn.parameter import Parameter
12 from torch.nn import Module
14 ######################################################################
17 def __init__(self, from_node, from_nb, to_node, to_nb):
18 self.from_node = from_node
19 self.from_nb = from_nb
20 self.to_node = to_node
24 def __init__(self, id, label):
30 def slot(node_list, n, k, for_input):
32 if node_list[n].max_out > 0:
33 return str(node_list[n].id) + ':input' + str(k)
35 return str(node_list[n].id)
37 if node_list[n].max_in > 0:
38 return str(node_list[n].id) + ':output' + str(k)
40 return str(node_list[n].id)
42 def slot_string(k, for_input):
51 if not for_input: result = ' |' + result
52 result += ' { <' + label + '0> 0'
53 for j in range(1, k+1):
54 result += " | " + '<' + label + str(j) + '> ' + str(j)
56 if for_input: result = result + '| '
60 ######################################################################
def add_link(node_list, link_list, u, nu, v, nv):
    """Record an edge from slot ``nu`` of ``u`` to slot ``nv`` of ``v``.

    Appends a new ``Link`` to ``link_list`` and widens the per-node slot
    counters that are later used to decide whether a node is drawn with
    numbered ports: the source node's ``max_in`` keeps the largest ``nu``
    seen so far, and the destination node's ``max_out`` keeps the largest
    ``nv`` seen so far.

    Args:
        node_list: mapping indexed by graph object, whose values carry
            ``max_in`` / ``max_out`` counters; both ``u`` and ``v`` must
            already be present, since both are indexed here.
        link_list: list that receives the new ``Link``.
        u: source object of the edge.
        nu: slot index on ``u``.
        v: destination object of the edge.
        nv: slot index on ``v``.
    """
    link_list.append(Link(u, nu, v, nv))
    node_list[u].max_in = max(node_list[u].max_in, nu)
    # The destination's counter must be widened from its own current value;
    # reading node_list[u].max_out here would drop v's previous maximum and
    # could replace it with an unrelated count belonging to u.
    node_list[v].max_out = max(node_list[v].max_out, nv)
68 ######################################################################
70 def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
72 if not u in node_list:
73 node = Node(len(node_list) + 1,
74 (u in node_labels and node_labels[u]) or \
75 re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
78 if isinstance(u, torch.autograd.Variable):
79 build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
80 add_link(node_list, link_list, u, 0, u.grad_fn, 0)
82 if hasattr(u, 'next_functions'):
84 for v, j in u.next_functions:
85 build_ag_graph_lists(v, node_labels, out, node_list, link_list)
86 add_link(node_list, link_list, u, i, v, j)
89 ######################################################################
91 def print_dot(node_list, link_list, out):
92 out.write('digraph{\n')
99 str(node.id) + ' [shape=record,label="{ ' + \
100 slot_string(node.max_out, for_input = True) + \
102 slot_string(node.max_in, for_input = False) + \
108 slot(node_list, n.from_node, n.from_nb, for_input = False) + \
110 slot(node_list, n.to_node, n.to_nb, for_input = True) + \
115 ######################################################################
117 def save_dot(x, node_labels = {}, out = sys.stdout):
120 build_ag_graph_lists(x, node_labels, out, node_list, link_list)
121 print_dot(node_list, link_list, out)
123 ######################################################################
125 # x = Variable(torch.rand(5))
126 # y = torch.topk(x, 3)
127 # l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
129 # save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))