OCD cosmetics.
[agtree2dot.git] / mlp.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
20 import subprocess
21
22 from torch import nn
23 from torch.nn import functional as fn
24 from torch import Tensor
25 from torch.autograd import Variable
26 from torch.nn import Module
27
28 import agtree2dot
29
class MLP(Module):
    """Two-layer perceptron: Linear -> tanh -> Linear.

    Args:
        input_dim: size of the last dimension of the input tensor.
        hidden_dim: number of hidden units between the two linear layers.
        output_dim: size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        # The attribute names fc1 / fc2 are part of the interface: the script
        # below labels mlp.fc1.weight, mlp.fc2.bias, etc. in the graph.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return fc2(tanh(fc1(x)))."""
        hidden = self.fc1(x).tanh()
        return self.fc2(hidden)
41
import sys  # used by sys.exit() on the error path below

# Build a tiny MLP and a scalar loss whose autograd graph will be rendered.
mlp = MLP(10, 20, 1)
inputs = Variable(Tensor(100, 10).normal_())
# The target must have the same shape as the network output, (100, 1).
# A (100,) target would be broadcast by MSELoss against the (100, 1) output
# into a (100, 100) comparison instead of an element-wise one.
target = Variable(Tensor(100, 1).normal_())
output = mlp(inputs)
criterion = nn.MSELoss()
loss = criterion(output, target)

# Dump the autograd graph of `loss`, giving readable names to the nodes of
# interest. The `with` block guarantees the .dot file is flushed and closed
# before Graphviz reads it below.
with open('./mlp.dot', 'w') as dot_file:
    agtree2dot.save_dot(loss,
                        {
                            inputs: 'input',
                            target: 'target',
                            loss: 'loss',
                            mlp.fc1.weight: 'weight1',
                            mlp.fc1.bias: 'bias1',
                            mlp.fc2.weight: 'weight2',
                            mlp.fc2.bias: 'bias2',
                        },
                        dot_file)

print('Generated mlp.dot')

# Render the .dot file to PDF with Graphviz, using the same font for
# nodes and edges.
fontname = 'Computer Modern'
fontsize = 12

try:
    subprocess.check_call(['dot', 'mlp.dot',
                           '-Lg',
                           '-T', 'pdf',
                           '-Efontname=' + fontname, '-Efontsize=' + str(fontsize),
                           '-Nfontname=' + fontname, '-Nfontsize=' + str(fontsize),
                           '-o', 'mlp.pdf'])
except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: dot ran but exited non-zero.
    # OSError (FileNotFoundError on Python 3): the dot executable could not
    # be launched at all, e.g. because Graphviz is not installed.
    print('Calling the dot command failed. Is Graphviz installed?')
    sys.exit(1)

print('Generated mlp.pdf')