X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=agtree2dot.py;h=8931e366b5f2937ba25330fe4268c4143d387753;hb=55332826c1d0ec125fc1d2db6644c98b1640d4a2;hp=2d89af5cc8130d90e9e8087e549b45f66b69aefd;hpb=35a4a9a26e7b35e507755a5d4fe3ea7f4f1ca6e0;p=agtree2dot.git
diff --git a/agtree2dot.py b/agtree2dot.py
index 2d89af5..8931e36 100755
--- a/agtree2dot.py
+++ b/agtree2dot.py
@@ -1,15 +1,22 @@
-#!/usr/bin/env python-for-pytorch
+#########################################################################
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation. #
+# #
+# This program is distributed in the hope that it will be useful, but #
+# WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
+# General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see . #
+# #
+# Written by and Copyright (C) Francois Fleuret #
+# Contact for comments & bug reports #
+#########################################################################
import torch
-import math, sys, re
-
-from torch import nn
-from torch.nn import functional as fn
-
-from torch import Tensor
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
-from torch.nn import Module
+import sys, re
######################################################################
@@ -50,7 +57,7 @@ def slot_string(k, for_input):
if k > 0:
if not for_input: result = ' |' + result
result += ' { <' + label + '0> 0'
- for j in range(1, k+1):
+ for j in range(1, k + 1):
result += " | " + '<' + label + str(j) + '> ' + str(j)
result += " } "
if for_input: result = result + '| '
@@ -60,29 +67,30 @@ def slot_string(k, for_input):
######################################################################
def add_link(node_list, link_list, u, nu, v, nv):
- link = Link(u, nu, v, nv)
- link_list.append(link)
- node_list[u].max_in = max(node_list[u].max_in, nu)
- node_list[v].max_out = max(node_list[u].max_out, nv)
+ if u is not None and v is not None:
+ link = Link(u, nu, v, nv)
+ link_list.append(link)
+ node_list[u].max_in = max(node_list[u].max_in, nu)
+ node_list[v].max_out = max(node_list[v].max_out, nv)
######################################################################
-def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
+def build_ag_graph_lists(u, node_labels, node_list, link_list):
- if not u in node_list:
+ if u is not None and not u in node_list:
node = Node(len(node_list) + 1,
(u in node_labels and node_labels[u]) or \
re.search('', str(type(u))).group(2))
node_list[u] = node
if isinstance(u, torch.autograd.Variable):
- build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
+ build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
else:
if hasattr(u, 'next_functions'):
i = 0
for v, j in u.next_functions:
- build_ag_graph_lists(v, node_labels, out, node_list, link_list)
+ build_ag_graph_lists(v, node_labels, node_list, link_list)
add_link(node_list, link_list, u, i, v, j)
i += 1
@@ -115,15 +123,8 @@ def print_dot(node_list, link_list, out):
######################################################################
def save_dot(x, node_labels = {}, out = sys.stdout):
- node_list = {}
- link_list = []
- build_ag_graph_lists(x, node_labels, out, node_list, link_list)
+ node_list, link_list = {}, []
+ build_ag_graph_lists(x, node_labels, node_list, link_list)
print_dot(node_list, link_list, out)
######################################################################
-
-# x = Variable(torch.rand(5))
-# y = torch.topk(x, 3)
-# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
-
-# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))