From: Francois Fleuret
Date: Sun, 20 Aug 2017 20:17:44 +0000 (+0200)
Subject: Complete re-write for pytorch 0.2.0
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=commitdiff_plain;h=35a4a9a26e7b35e507755a5d4fe3ea7f4f1ca6e0

Complete re-write for pytorch 0.2.0
---

diff --git a/agtree2dot.py b/agtree2dot.py
index f215f94..2d89af5 100755
--- a/agtree2dot.py
+++ b/agtree2dot.py
@@ -1,108 +1,129 @@
-
-#########################################################################
-# This program is free software: you can redistribute it and/or modify  #
-# it under the terms of the version 3 of the GNU General Public License #
-# as published by the Free Software Foundation.                         #
-#                                                                       #
-# This program is distributed in the hope that it will be useful, but   #
-# WITHOUT ANY WARRANTY; without even the implied warranty of            #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
-# General Public License for more details.                              #
-#                                                                       #
-# You should have received a copy of the GNU General Public License     #
-# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
-#                                                                       #
-# Written by and Copyright (C) Francois Fleuret                         #
-# Contact for comments & bug reports                                    #
-#########################################################################
+#!/usr/bin/env python-for-pytorch

 import torch
-import re
-import sys
+import math, sys, re

-import torch.autograd
+from torch import nn
+from torch.nn import functional as fn

-######################################################################
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+from torch.nn import Module

-def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
+######################################################################

-    if isinstance(x, set):
+class Link:
+    def __init__(self, from_node, from_nb, to_node, to_nb):
+        self.from_node = from_node
+        self.from_nb = from_nb
+        self.to_node = to_node
+        self.to_nb = to_nb
+
+class Node:
+    def __init__(self, id, label):
+        self.id = id
+        self.label = label
+        self.max_in = -1
+        self.max_out = -1
+
+def slot(node_list, n, k, for_input):
+    if for_input:
+        if node_list[n].max_out > 0:
+            return str(node_list[n].id) + ':input' + str(k)
+        else:
+            return str(node_list[n].id)
+    else:
+        if node_list[n].max_in > 0:
+            return str(node_list[n].id) + ':output' + str(k)
+        else:
+            return str(node_list[n].id)

-        for y in x:
-            save_dot_rec(y, node_labels, out, drawn_node_id)
+def slot_string(k, for_input):
+    result = ''
+
+    if for_input:
+        label = 'input'
     else:
+        label = 'output'
+
+    if k > 0:
+        if not for_input: result = ' |' + result
+        result += ' { <' + label + '0> 0'
+        for j in range(1, k+1):
+            result += " | " + '<' + label + str(j) + '> ' + str(j)
+        result += " } "
+        if for_input: result = result + '| '

-        if not x in drawn_node_id:
-            drawn_node_id[x] = len(drawn_node_id) + 1
-
-            # Draw the node (Variable or Function) if not already
-            # drawn
-
-            if isinstance(x, torch.autograd.Variable):
-                name = ((x in node_labels and node_labels[x]) or 'Variable')
-                # Add the tensor size
-                name = name + ' ['
-                for d in range(0, x.data.dim()):
-                    if d > 0: name = name + ', '
-                    name = name + str(x.data.size(d))
-                name = name + ']'
-
-                out.write(' ' + str(drawn_node_id[x]) +
-                          ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
-
-                if hasattr(x, 'creator') and x.creator:
-                    y = x.creator
-                    save_dot_rec(y, node_labels, out, drawn_node_id)
-                    # Edge to the creator
-                    out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n')
-
-            elif isinstance(x, torch.autograd.Function):
-                name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
-                       re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
-
-                prefix = ''
-                suffix = ''
-
-                if hasattr(x, 'num_inputs') and x.num_inputs > 1:
-                    prefix = '{ '
-                    for i in range(0, x.num_inputs):
-                        if i > 0: prefix = prefix + ' | '
-                        prefix = prefix + '<input' + str(i) + '> ' + str(i)
-                    prefix = prefix + ' } | '
-
-                if hasattr(x, 'num_outputs') and x.num_outputs > 1:
-                    suffix = ' | { '
-                    for i in range(0, x.num_outputs):
-                        if i > 0: suffix = suffix + ' | '
-                        suffix = suffix + '<output' + str(i) + '> ' + str(i)
-                    suffix = suffix + ' }'
-
-                out.write(' ' + str(drawn_node_id[x]) + \
-                          ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n')
-
-            else:
-
-                print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).')
-                exit(1)
-
-            if hasattr(x, 'num_inputs'):
-                for i in range(0, x.num_inputs):
-                    y = x.previous_functions[i][0]
-                    save_dot_rec(y, node_labels, out, drawn_node_id)
-                    from_str = str(drawn_node_id[y])
-                    if hasattr(y, 'num_outputs') and y.num_outputs > 1:
-                        from_str = from_str + ':output' + str(x.previous_functions[i][1])
-                    to_str = str(drawn_node_id[x])
-                    if x.num_inputs > 1:
-                        to_str = to_str + ':input' + str(i)
-                    out.write(' ' + from_str + ' -> ' + to_str + '\n')
+    return result

 ######################################################################

-def save_dot(x, node_labels = {}, out = sys.stdout):
-    out.write('digraph {\n')
-    save_dot_rec(x, node_labels, out, {})
+def add_link(node_list, link_list, u, nu, v, nv):
+    link = Link(u, nu, v, nv)
+    link_list.append(link)
+    node_list[u].max_in = max(node_list[u].max_in, nu)
+    node_list[v].max_out = max(node_list[v].max_out, nv)
+
+######################################################################
+
+def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
+
+    if not u in node_list:
+        node = Node(len(node_list) + 1,
+                    (u in node_labels and node_labels[u]) or \
+                    re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
+        node_list[u] = node
+
+        if isinstance(u, torch.autograd.Variable):
+            build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.grad_fn, 0)
+        else:
+            if hasattr(u, 'next_functions'):
+                i = 0
+                for v, j in u.next_functions:
+                    build_ag_graph_lists(v, node_labels, out, node_list, link_list)
+                    add_link(node_list, link_list, u, i, v, j)
+                    i += 1
+
+######################################################################
+
+def print_dot(node_list, link_list, out):
+    out.write('digraph{\n')
+
+    for n in node_list:
+        node = node_list[n]
+
+        out.write(
+            ' ' + \
+            str(node.id) + ' [shape=record,label="{ ' + \
+            slot_string(node.max_out, for_input = True) + \
+            node.label + \
+            slot_string(node.max_in, for_input = False) + \
+            ' }"]\n'
+        )
+
+    for n in link_list:
+        out.write(' ' + \
+                  slot(node_list, n.from_node, n.from_nb, for_input = False) + \
+                  ' -> ' + \
+                  slot(node_list, n.to_node, n.to_nb, for_input = True) + \
+                  '\n')
+
     out.write('}\n')

 ######################################################################
+
+def save_dot(x, node_labels = {}, out = sys.stdout):
+    node_list = {}
+    link_list = []
+    build_ag_graph_lists(x, node_labels, out, node_list, link_list)
+    print_dot(node_list, link_list, out)
+
+######################################################################
+
+# x = Variable(torch.rand(5))
+# y = torch.topk(x, 3)
+# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
+
+# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))
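
For reference, a minimal usage sketch of the new save_dot entry point, adapted from the commented-out example at the end of the file above. The requires_grad flag and the agtree2dot import are additions made here so the snippet is self-contained (they are not part of this commit); autograd only records a graph for Variables that require gradients:

    import torch
    from torch.autograd import Variable

    import agtree2dot

    x = Variable(torch.rand(5), requires_grad = True)  # leaf Variable tracked by autograd
    y = torch.topk(x, 3)                               # (values, indices) pair
    l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))

    # Walk the autograd graph hanging off l and write it in Graphviz dot syntax
    agtree2dot.save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))

Each Variable or autograd Function reached from l becomes one record node in the resulting dot file, with numbered ports for its inputs and outputs when it has more than one.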
diff --git a/mlp.pdf b/mlp.pdf
index 0f735a5..f233b20 100644
Binary files a/mlp.pdf and b/mlp.pdf differ
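
The updated mlp.pdf is the example drawing regenerated with the new code; assuming Graphviz is installed, such a PDF is obtained by rendering the generated dot file, e.g. with dot mlp.dot -T pdf -o mlp.pdf, where the file names are illustrative rather than taken from this commit.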