-
#########################################################################
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the version 3 of the GNU General Public License #
#########################################################################
import torch
-import re
-import sys
+import sys, re
+
+######################################################################
+
+class Link:
+ def __init__(self, from_node, from_nb, to_node, to_nb):
+ self.from_node = from_node
+ self.from_nb = from_nb
+ self.to_node = to_node
+ self.to_nb = to_nb
+
+class Node:
+ def __init__(self, id, label):
+ self.id = id
+ self.label = label
+ self.max_in = -1
+ self.max_out = -1
+
+def slot(node_list, n, k, for_input):
+ if for_input:
+ if node_list[n].max_out > 0:
+ return str(node_list[n].id) + ':input' + str(k)
+ else:
+ return str(node_list[n].id)
+ else:
+ if node_list[n].max_in > 0:
+ return str(node_list[n].id) + ':output' + str(k)
+ else:
+ return str(node_list[n].id)
-import torch.autograd
+def slot_string(k, for_input):
+ result = ''
+
+ if for_input:
+ label = 'input'
+ else:
+ label = 'output'
+
+ if k > 0:
+ if not for_input: result = ' |' + result
+ result += ' { <' + label + '0> 0'
+ for j in range(1, k + 1):
+ result += " | " + '<' + label + str(j) + '> ' + str(j)
+ result += " } "
+ if for_input: result = result + '| '
+
+ return result
######################################################################
-def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
+def add_link(node_list, link_list, u, nu, v, nv):
+ if u is not None and v is not None:
+ link = Link(u, nu, v, nv)
+ link_list.append(link)
+ node_list[u].max_in = max(node_list[u].max_in, nu)
+ node_list[v].max_out = max(node_list[v].max_out, nv)
- if isinstance(x, set):
+######################################################################
- for y in x:
- save_dot_rec(y, node_labels, out, drawn_node_id)
+def fill_graph_lists(u, node_labels, node_list, link_list):
- else:
+ if u is not None and not u in node_list:
+ node = Node(len(node_list) + 1,
+ (u in node_labels and node_labels[u]) or \
+ re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
+ node_list[u] = node
+
+ if hasattr(u, 'grad_fn'):
+ fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
+ add_link(node_list, link_list, u, 0, u.grad_fn, 0)
+
+ if hasattr(u, 'variable'):
+ fill_graph_lists(u.variable, node_labels, node_list, link_list)
+ add_link(node_list, link_list, u, 0, u.variable, 0)
- if not x in drawn_node_id:
- drawn_node_id[x] = len(drawn_node_id) + 1
-
- # Draw the node (Variable or Function) if not already
- # drawn
-
- if isinstance(x, torch.autograd.Variable):
- name = ((x in node_labels and node_labels[x]) or 'Variable')
- # Add the tensor size
- name = name + ' ['
- for d in range(0, x.data.dim()):
- if d > 0: name = name + ', '
- name = name + str(x.data.size(d))
- name = name + ']'
-
- out.write(' ' + str(drawn_node_id[x]) +
- ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
-
- if hasattr(x, 'creator') and x.creator:
- y = x.creator
- save_dot_rec(y, node_labels, out, drawn_node_id)
- # Edge to the creator
- out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n')
-
- elif isinstance(x, torch.autograd.Function):
- name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
- re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
-
- prefix = ''
- suffix = ''
-
- if hasattr(x, 'num_inputs') and x.num_inputs > 1:
- prefix = '{ '
- for i in range(0, x.num_inputs):
- if i > 0: prefix = prefix + ' | '
- prefix = prefix + '<input' + str(i) + '> ' + str(i)
- prefix = prefix + ' } | '
-
- if hasattr(x, 'num_outputs') and x.num_outputs > 1:
- suffix = ' | { '
- for i in range(0, x.num_outputs):
- if i > 0: suffix = suffix + ' | '
- suffix = suffix + '<output' + str(i) + '> ' + str(i)
- suffix = suffix + ' }'
-
- out.write(' ' + str(drawn_node_id[x]) + \
- ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
-
- else:
-
- print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
- exit(1)
-
- if hasattr(x, 'num_inputs'):
- for i in range(0, x.num_inputs):
- y = x.previous_functions[i][0]
- save_dot_rec(y, node_labels, out, drawn_node_id)
- from_str = str(drawn_node_id[y])
- if hasattr(y, 'num_outputs') and y.num_outputs > 1:
- from_str = from_str + ':output' + str(x.previous_functions[i][1])
- to_str = str(drawn_node_id[x])
- if x.num_inputs > 1:
- to_str = to_str + ':input' + str(i)
- out.write(' ' + from_str + ' -> ' + to_str + '\n')
+ if hasattr(u, 'next_functions'):
+ i = 0
+ for v, j in u.next_functions:
+ fill_graph_lists(v, node_labels, node_list, link_list)
+ add_link(node_list, link_list, u, i, v, j)
+ i += 1
######################################################################
-def save_dot(x, node_labels = {}, out = sys.stdout):
- out.write('digraph {\n')
- save_dot_rec(x, node_labels, out, {})
+def print_dot(node_list, link_list, out):
+ out.write('digraph{\n')
+
+ out.write(' graph [fontname = "helvetica"];\n')
+ out.write(' node [fontname = "helvetica"];\n')
+ out.write(' edge [fontname = "helvetica"];\n')
+
+ for n in node_list:
+ node = node_list[n]
+
+ if isinstance(n, torch.autograd.Variable):
+ out.write(
+ ' ' + \
+ str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
+ node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+ '"]\n'
+ )
+ else:
+ out.write(
+ ' ' + \
+ str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
+ slot_string(node.max_out, for_input = True) + \
+ node.label + \
+ slot_string(node.max_in, for_input = False) + \
+ ' }"]\n'
+ )
+
+ for n in link_list:
+ out.write(' ' + \
+ slot(node_list, n.from_node, n.from_nb, for_input = False) + \
+ ' -> ' + \
+ slot(node_list, n.to_node, n.to_nb, for_input = True) + \
+ '\n')
+
out.write('}\n')
######################################################################
+
+def save_dot(x, node_labels = {}, out = sys.stdout):
+ node_list, link_list = {}, []
+ fill_graph_lists(x, node_labels, node_list, link_list)
+ print_dot(node_list, link_list, out)
+
+######################################################################