X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=agtree2dot.py;h=4eef05a4fca1cd967af795b817a3827fd0cdd7d8;hb=75c4f522cda71ef830148ec8d881157322be9502;hp=8cc9e8cb4cad429a3c8b4b338bf16b9619f43d01;hpb=d9e6125c82f4e3775b8c868751b1657a6a147f55;p=agtree2dot.git diff --git a/agtree2dot.py b/agtree2dot.py index 8cc9e8c..4eef05a 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -83,7 +83,7 @@ def fill_graph_lists(u, node_labels, node_list, link_list): re.search('', str(type(u))).group(2)) node_list[u] = node - if isinstance(u, torch.autograd.Variable): + if hasattr(u, 'grad_fn'): fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) @@ -103,20 +103,24 @@ def fill_graph_lists(u, node_labels, node_list, link_list): def print_dot(node_list, link_list, out): out.write('digraph{\n') + out.write(' graph [fontname = "helvetica"];\n') + out.write(' node [fontname = "helvetica"];\n') + out.write(' edge [fontname = "helvetica"];\n') + for n in node_list: node = node_list[n] if isinstance(n, torch.autograd.Variable): out.write( ' ' + \ - str(node.id) + ' [shape=note,label="' + \ + str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \ node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \ '"]\n' ) else: out.write( ' ' + \ - str(node.id) + ' [shape=record,label="{ ' + \ + str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \ slot_string(node.max_out, for_input = True) + \ node.label + \ slot_string(node.max_in, for_input = False) + \