# Introduction #

This package provides a function that generates a
[dot file](https://en.wikipedia.org/wiki/DOT_(graph_description_language))
from a [PyTorch](http://pytorch.org) autograd graph.

# Usage #

## Functions ##

### agtree2dot.save_dot(variable, variable_labels, result_file) ###

Writes to `result_file`, an open writable file, a dot description of
the autograd graph of the `Variable` `variable`. The dictionary
`variable_labels` maps some of the variables to strings, which are
used as their labels in the resulting graph.

## Example ##

A typical use is provided in [mlp.py](https://fleuret.org/git-extract/agtree2dot/mlp.py):

```python
import subprocess
import sys

from torch import nn
from torch.nn import functional as fn
from torch import Tensor
from torch.autograd import Variable
from torch.nn import Module

import agtree2dot

class MLP(Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = fn.tanh(x)
        x = self.fc2(x)
        return x

mlp = MLP(10, 20, 1)
input = Variable(Tensor(100, 10).normal_())
# The target has the same shape as the output, (100, 1), so that
# MSELoss compares them element-wise instead of broadcasting.
target = Variable(Tensor(100, 1).normal_())
output = mlp(input)
criterion = nn.MSELoss()
loss = criterion(output, target)

# Give readable labels to some variables of the graph. The with block
# closes (and thus flushes) mlp.dot before dot reads it below.
with open('./mlp.dot', 'w') as f:
    agtree2dot.save_dot(loss,
                        {
                            input: 'input',
                            target: 'target',
                            loss: 'loss',
                            mlp.fc1.weight: 'weight1',
                            mlp.fc1.bias: 'bias1',
                            mlp.fc2.weight: 'weight2',
                            mlp.fc2.bias: 'bias2',
                        },
                        f)

print('Generated mlp.dot')

try:
    subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf"])
except (OSError, subprocess.CalledProcessError):
    # OSError covers a missing dot executable, CalledProcessError a failed run.
    print('Calling the dot command failed. Is Graphviz installed?')
    sys.exit(1)

print('Generated mlp.pdf')
```

This writes `mlp.dot` and then tries to convert it into
[mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf) with the
`dot` command from [Graphviz](http://www.graphviz.org/).
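
The conversion step can also be wrapped in a small reusable function. The sketch below is an assumption, not part of agtree2dot: the name `render_dot` and its `dot_executable` parameter are hypothetical. It checks for the Graphviz `dot` executable with `shutil.which` before running it, so a missing install is reported explicitly, and it uses only the standard `-Tpdf` (output format) and `-o` (output file) options.

```python
import shutil
import subprocess


def render_dot(dot_path: str, pdf_path: str, dot_executable: str = "dot") -> bool:
    # Hypothetical helper: convert a DOT file to PDF, returning False on failure
    # instead of raising.
    if shutil.which(dot_executable) is None:
        # Without this check a missing binary surfaces as FileNotFoundError.
        print(f"{dot_executable!r} not found. Is Graphviz installed?")
        return False
    try:
        # -Tpdf selects the PDF output format, -o names the output file.
        subprocess.run([dot_executable, "-Tpdf", dot_path, "-o", pdf_path], check=True)
    except subprocess.CalledProcessError as err:
        print(f"{dot_executable} exited with status {err.returncode}")
        return False
    return True
```

For example, `render_dot('mlp.dot', 'mlp.pdf')` returns `True` once `mlp.pdf` has been written, and `False` if Graphviz is absent or `dot` rejects the file.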