X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=agtree2dot.py;h=8931e366b5f2937ba25330fe4268c4143d387753;hb=55332826c1d0ec125fc1d2db6644c98b1640d4a2;hp=2d89af5cc8130d90e9e8087e549b45f66b69aefd;hpb=35a4a9a26e7b35e507755a5d4fe3ea7f4f1ca6e0;p=agtree2dot.git diff --git a/agtree2dot.py b/agtree2dot.py index 2d89af5..8931e36 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -1,15 +1,22 @@ -#!/usr/bin/env python-for-pytorch +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### import torch -import math, sys, re - -from torch import nn -from torch.nn import functional as fn - -from torch import Tensor -from torch.autograd import Variable -from torch.nn.parameter import Parameter -from torch.nn import Module +import sys, re ###################################################################### @@ -50,7 +57,7 @@ def slot_string(k, for_input): if k > 0: if not for_input: result = ' |' + result result += ' { <' + label + '0> 0' - for j in range(1, k+1): + for j in range(1, k + 1): result += " | " + '<' + label + str(j) + '> ' + str(j) result += " } " if for_input: result = result + '| ' @@ -60,29 +67,30 @@ def slot_string(k, for_input): ###################################################################### def add_link(node_list, link_list, u, nu, v, nv): - link = Link(u, nu, v, nv) - link_list.append(link) - node_list[u].max_in = max(node_list[u].max_in, nu) - node_list[v].max_out = max(node_list[u].max_out, nv) + if u is not None and v is not None: + link = Link(u, nu, v, nv) + link_list.append(link) + node_list[u].max_in = max(node_list[u].max_in, nu) + node_list[v].max_out = max(node_list[v].max_out, nv) ###################################################################### -def build_ag_graph_lists(u, node_labels, out, node_list, link_list): +def build_ag_graph_lists(u, node_labels, node_list, link_list): - if not u in node_list: + if u is not None and not u in node_list: node = Node(len(node_list) + 1, (u in node_labels and node_labels[u]) or \ re.search('', str(type(u))).group(2)) node_list[u] = node if isinstance(u, torch.autograd.Variable): - build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list) + build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) else: if hasattr(u, 'next_functions'): i = 0 for v, j in u.next_functions: - build_ag_graph_lists(v, node_labels, out, node_list, link_list) + build_ag_graph_lists(v, node_labels, node_list, link_list) add_link(node_list, link_list, u, i, v, j) i += 1 @@ -115,15 +123,8 @@ def print_dot(node_list, link_list, out): ###################################################################### def save_dot(x, node_labels = {}, out = sys.stdout): - node_list = {} - link_list = [] - build_ag_graph_lists(x, node_labels, out, node_list, link_list) + node_list, link_list = {}, [] + build_ag_graph_lists(x, node_labels, node_list, link_list) print_dot(node_list, link_list, out) ###################################################################### - -# x = Variable(torch.rand(5)) -# y = torch.topk(x, 3) -# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float())) - -# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))