Log the number of parameters.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 from colorama import Fore, Back, Style
27
28 import torch
29
30 from torch import optim
31 from torch import FloatTensor as Tensor
32 from torch.autograd import Variable
33 from torch import nn
34 from torch.nn import functional as fn
35 from torchvision import datasets, transforms, utils
36
37 import svrt
38
39 ######################################################################
40
41 parser = argparse.ArgumentParser(
42     description = 'Simple convnet test on the SVRT.',
43     formatter_class = argparse.ArgumentDefaultsHelpFormatter
44 )
45
46 parser.add_argument('--nb_train_samples',
47                     type = int, default = 100000,
48                     help = 'How many samples for train')
49
50 parser.add_argument('--nb_test_samples',
51                     type = int, default = 10000,
52                     help = 'How many samples for test')
53
54 parser.add_argument('--nb_epochs',
55                     type = int, default = 25,
56                     help = 'How many training epochs')
57
58 parser.add_argument('--log_file',
59                     type = str, default = 'cnn-svrt.log',
60                     help = 'Log file name')
61
62 args = parser.parse_args()
63
64 ######################################################################
65
66 log_file = open(args.log_file, 'w')
67
68 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
69
70 def log_string(s):
71     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
72         str(problem_number) + ' ' + s
73     log_file.write(s + '\n')
74     log_file.flush()
75     print(s)
76
77 ######################################################################
78
79 def generate_set(p, n):
80     target = torch.LongTensor(n).bernoulli_(0.5)
81     t = time.time()
82     input = svrt.generate_vignettes(p, target)
83     t = time.time() - t
84     log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
85     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86     return Variable(input), Variable(target)
87
88 ######################################################################
89
90 # Afroze's ShallowNet
91
92 #                    map size   nb. maps
93 #                  ----------------------
94 #                    128x128    1
95 # -- conv(21x21)  -> 108x108    6
96 # -- max(2x2)     -> 54x54      6
97 # -- conv(19x19)  -> 36x36      16
98 # -- max(2x2)     -> 18x18      16
99 # -- conv(18x18)  -> 1x1        120
100 # -- reshape      -> 120        1
101 # -- full(120x84) -> 84         1
102 # -- full(84x2)   -> 2          1
103
104 class Net(nn.Module):
105     def __init__(self):
106         super(Net, self).__init__()
107         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
108         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
109         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
110         self.fc1 = nn.Linear(120, 84)
111         self.fc2 = nn.Linear(84, 2)
112
113     def forward(self, x):
114         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
115         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
116         x = fn.relu(self.conv3(x))
117         x = x.view(-1, 120)
118         x = fn.relu(self.fc1(x))
119         x = self.fc2(x)
120         return x
121
122 def train_model(train_input, train_target):
123     model, criterion = Net(), nn.CrossEntropyLoss()
124
125     nb_parameters = 0
126     for p in model.parameters():
127         nb_parameters += p.numel()
128     log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
129
130     if torch.cuda.is_available():
131         model.cuda()
132         criterion.cuda()
133
134     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
135
136     for k in range(0, args.nb_epochs):
137         acc_loss = 0.0
138         for b in range(0, train_input.size(0), bs):
139             output = model.forward(train_input.narrow(0, b, bs))
140             loss = criterion(output, train_target.narrow(0, b, bs))
141             acc_loss = acc_loss + loss.data[0]
142             model.zero_grad()
143             loss.backward()
144             optimizer.step()
145         log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
146
147     return model
148
149 ######################################################################
150
151 def nb_errors(model, data_input, data_target, bs = 100):
152     ne = 0
153
154     for b in range(0, data_input.size(0), bs):
155         output = model.forward(data_input.narrow(0, b, bs))
156         wta_prediction = output.data.max(1)[1].view(-1)
157
158         for i in range(0, bs):
159             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
160                 ne = ne + 1
161
162     return ne
163
164 ######################################################################
165
166 for problem_number in range(1, 24):
167     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
168     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
169
170     if torch.cuda.is_available():
171         train_input, train_target = train_input.cuda(), train_target.cuda()
172         test_input, test_target = test_input.cuda(), test_target.cuda()
173
174     mu, std = train_input.data.mean(), train_input.data.std()
175     train_input.data.sub_(mu).div_(std)
176     test_input.data.sub_(mu).div_(std)
177
178     model = train_model(train_input, train_target)
179
180     nb_train_errors = nb_errors(model, train_input, train_target)
181
182     log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
183         100 * nb_train_errors / train_input.size(0),
184         nb_train_errors,
185         train_input.size(0))
186     )
187
188     nb_test_errors = nb_errors(model, test_input, test_target)
189
190     log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
191         100 * nb_test_errors / test_input.size(0),
192         nb_test_errors,
193         test_input.size(0))
194     )
195
196 ######################################################################