From: Francois Fleuret Date: Mon, 21 Aug 2017 06:19:09 +0000 (+0200) Subject: Style variable nodes differently, shows the tensor size, invoke the dot command in... X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=d9e6125c82f4e3775b8c868751b1657a6a147f55;p=agtree2dot.git Style variable nodes differently, shows the tensor size, invoke the dot command in mlp.py. --- diff --git a/README.md b/README.md index 6b996b4..452aa9c 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,17 @@ from a [pytorch](http://pytorch.org) autograd graph. ### agtree2dot.save_dot(variable, variable_labels, result_file) ### -Saves into `result_file` a dot file corresponding to the autograd graph for `variable`, which can be either a single `Variable` or a set of `Variable`s. The dictionary `variable_labels` associates strings to some variables, which will be used in the resulting graph. +Saves into `result_file` a dot file corresponding to the autograd +graph for the `Variable` `variable`. The dictionary `variable_labels` +associates strings to some variables, which will be used in the +resulting graph. ## Example ## -A typical use would be: +A typical use is provided in [mlp.py](https://fleuret.org/git-extract/agtree2dot/mlp.py): ```python -import torch +import subprocess from torch import nn from torch.nn import functional as fn @@ -47,15 +50,28 @@ criterion = nn.MSELoss() loss = criterion(output, target) agtree2dot.save_dot(loss, - { input: 'input', target: 'target', loss: 'loss' }, + { + input: 'input', + target: 'target', + loss: 'loss', + mlp.fc1.weight: 'weight1', + mlp.fc1.bias: 'bias1', + mlp.fc2.weight: 'weight2', + mlp.fc2.bias: 'bias2', + }, open('./mlp.dot', 'w')) -``` -which would generate a file mlp.dot, which can then be translated to -pdf using the [Graphviz tools](http://www.graphviz.org/) +print('Generated mlp.dot') -``` -dot mlp.dot -Lg -T pdf -o mlp.pdf +try: + subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ]) +except subprocess.CalledProcessError: + print('Calling the dot command failed. Is Graphviz installed?') + sys.exit(1) + +print('Generated mlp.pdf') ``` -to produce [mlp.pdf.](https://fleuret.org/git-extract/agtree2dot/mlp.pdf) +which would generate a file mlp.dot and try to generate +[mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf) from it +with [Graphviz tools.](http://www.graphviz.org/) diff --git a/agtree2dot.py b/agtree2dot.py index 8931e36..8cc9e8c 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -75,7 +75,7 @@ def add_link(node_list, link_list, u, nu, v, nv): ###################################################################### -def build_ag_graph_lists(u, node_labels, node_list, link_list): +def fill_graph_lists(u, node_labels, node_list, link_list): if u is not None and not u in node_list: node = Node(len(node_list) + 1, @@ -84,15 +84,19 @@ def build_ag_graph_lists(u, node_labels, node_list, link_list): node_list[u] = node if isinstance(u, torch.autograd.Variable): - build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list) + fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) - else: - if hasattr(u, 'next_functions'): - i = 0 - for v, j in u.next_functions: - build_ag_graph_lists(v, node_labels, node_list, link_list) - add_link(node_list, link_list, u, i, v, j) - i += 1 + + if hasattr(u, 'variable'): + fill_graph_lists(u.variable, node_labels, node_list, link_list) + add_link(node_list, link_list, u, 0, u.variable, 0) + + if hasattr(u, 'next_functions'): + i = 0 + for v, j in u.next_functions: + fill_graph_lists(v, node_labels, node_list, link_list) + add_link(node_list, link_list, u, i, v, j) + i += 1 ###################################################################### @@ -102,14 +106,22 @@ def print_dot(node_list, link_list, out): for n in node_list: node = node_list[n] - out.write( - ' ' + \ - str(node.id) + ' [shape=record,label="{ ' + \ - slot_string(node.max_out, for_input = True) + \ - node.label + \ - slot_string(node.max_in, for_input = False) + \ - ' }"]\n' - ) + if isinstance(n, torch.autograd.Variable): + out.write( + ' ' + \ + str(node.id) + ' [shape=note,label="' + \ + node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \ + '"]\n' + ) + else: + out.write( + ' ' + \ + str(node.id) + ' [shape=record,label="{ ' + \ + slot_string(node.max_out, for_input = True) + \ + node.label + \ + slot_string(node.max_in, for_input = False) + \ + ' }"]\n' + ) for n in link_list: out.write(' ' + \ @@ -124,7 +136,7 @@ def print_dot(node_list, link_list, out): def save_dot(x, node_labels = {}, out = sys.stdout): node_list, link_list = {}, [] - build_ag_graph_lists(x, node_labels, node_list, link_list) + fill_graph_lists(x, node_labels, node_list, link_list) print_dot(node_list, link_list, out) ###################################################################### diff --git a/mlp.pdf b/mlp.pdf index 0f41f81..52dea89 100644 Binary files a/mlp.pdf and b/mlp.pdf differ diff --git a/mlp.py b/mlp.py index 8497848..3c5f026 100755 --- a/mlp.py +++ b/mlp.py @@ -17,6 +17,8 @@ # Contact for comments & bug reports # ######################################################################### +import subprocess + from torch import nn from torch.nn import functional as fn from torch import Tensor @@ -45,8 +47,23 @@ criterion = nn.MSELoss() loss = criterion(output, target) agtree2dot.save_dot(loss, - { input: 'input', target: 'target', loss: 'loss' }, + { + input: 'input', + target: 'target', + loss: 'loss', + mlp.fc1.weight: 'weight1', + mlp.fc1.bias: 'bias1', + mlp.fc2.weight: 'weight2', + mlp.fc2.bias: 'bias2', + }, open('./mlp.dot', 'w')) -print('Generated mlp.dot. You can convert it to pdf with') -print('> dot mlp.dot -Lg -T pdf -o mlp.pdf') +print('Generated mlp.dot') + +try: + subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ]) +except subprocess.CalledProcessError: + print('Calling the dot command failed. Is Graphviz installed?') + sys.exit(1) + +print('Generated mlp.pdf')