Update.
[agtree2dot.git] / mlp.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
20 import subprocess
21
22 from torch import nn
23 from torch.nn import functional as fn
24 from torch import Tensor
25 from torch.nn import Module
26
27 import agtree2dot
28
class MLP(Module):
    """Perceptron with one hidden layer: linear -> tanh -> linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two fully connected layers.

        Args:
            input_dim: number of features of each input sample.
            hidden_dim: number of units in the hidden layer.
            output_dim: number of features of each output sample.
        """
        super(MLP, self).__init__()
        # fc1 is created before fc2 so that parameter registration order
        # (and therefore random initialisation order) stays the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return fc2(tanh(fc1(x))) for the input tensor `x`."""
        hidden = fn.tanh(self.fc1(x))
        return self.fc2(hidden)
40
# Build a small MLP, run it on random data and compute an MSE loss, then
# hand the loss to agtree2dot.save_dot to write './mlp.dot'.
mlp = MLP(10, 20, 1)
# NOTE(review): `input` shadows the builtin of the same name; the name is
# kept so existing references to this module-level variable still work.
input = Tensor(100, 10).normal_()
target = Tensor(100, 1).normal_()
output = mlp(input)
criterion = nn.MSELoss()
loss = criterion(output, target)

# Tensor -> name mapping passed to save_dot (presumably used as node
# labels in the generated graph; confirm against agtree2dot).
labels = {
    input: 'input',
    target: 'target',
    loss: 'loss',
    mlp.fc1.weight: 'weight1',
    mlp.fc1.bias: 'bias1',
    mlp.fc2.weight: 'weight2',
    mlp.fc2.bias: 'bias2',
}

# The context manager guarantees the file is flushed and closed before
# anything else reads it, instead of relying on garbage collection.
with open('./mlp.dot', 'w') as dot_file:
    agtree2dot.save_dot(loss, labels, dot_file)

print('Generated mlp.dot')
61
# Convert mlp.dot to mlp.pdf by running the external `dot` command.
import sys  # needed by the failure path below; not imported at file level

fontname = 'Computer Modern'
fontsize = 12

try:

    subprocess.check_call(['dot', 'mlp.dot',
                           '-Lg',
                           '-T', 'pdf',
                           '-Efontname=' + fontname, '-Efontsize=' + str(fontsize),
                           '-Nfontname=' + fontname, '-Nfontsize=' + str(fontsize),
                           '-o', 'mlp.pdf' ])

except (OSError, subprocess.CalledProcessError):

    # OSError (FileNotFoundError on Python 3) is raised when the `dot`
    # executable cannot be launched at all, e.g. when it is not installed;
    # CalledProcessError is raised when it runs but exits with a non-zero
    # status. Both are reported the same way.
    print('Calling the dot command failed. Is Graphviz installed?')
    sys.exit(1)

print('Generated mlp.pdf')