Replaced the nb of batches arguments with nb of samples.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_samples',
53                     type = int, default = 100000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'default.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 parser.add_argument('--deep_model',
77                     action='store_true', default = False,
78                     help = 'Use Afroze\'s Alexnet-like deep model')
79
80 parser.add_argument('--test_loaded_models',
81                     action='store_true', default = False,
82                     help = 'Should we compute the test errors of loaded models')
83
84 args = parser.parse_args()
85
86 ######################################################################
87
88 log_file = open(args.log_file, 'w')
89 pred_log_t = None
90
91 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
92
93 # Log and prints the string, with a time stamp. Does not log the
94 # remark
95 def log_string(s, remark = ''):
96     global pred_log_t
97
98     t = time.time()
99
100     if pred_log_t is None:
101         elapsed = 'start'
102     else:
103         elapsed = '+{:.02f}s'.format(t - pred_log_t)
104
105     pred_log_t = t
106
107     s = Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
108     log_file.write(s + '\n')
109     log_file.flush()
110     print(s + Fore.CYAN + remark + Style.RESET_ALL)
111
112 ######################################################################
113
114 # Afroze's ShallowNet
115
116 #                       map size   nb. maps
117 #                     ----------------------
118 #    input                128x128    1
119 # -- conv(21x21 x 6)   -> 108x108    6
120 # -- max(2x2)          -> 54x54      6
121 # -- conv(19x19 x 16)  -> 36x36      16
122 # -- max(2x2)          -> 18x18      16
123 # -- conv(18x18 x 120) -> 1x1        120
124 # -- reshape           -> 120        1
125 # -- full(120x84)      -> 84         1
126 # -- full(84x2)        -> 2          1
127
128 class AfrozeShallowNet(nn.Module):
129     def __init__(self):
130         super(AfrozeShallowNet, self).__init__()
131         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
132         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
133         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
134         self.fc1 = nn.Linear(120, 84)
135         self.fc2 = nn.Linear(84, 2)
136         self.name = 'shallownet'
137
138     def forward(self, x):
139         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
140         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
141         x = fn.relu(self.conv3(x))
142         x = x.view(-1, 120)
143         x = fn.relu(self.fc1(x))
144         x = self.fc2(x)
145         return x
146
147 ######################################################################
148
149 # Afroze's DeepNet
150
151 #                       map size   nb. maps
152 #                     ----------------------
153 #    input                128x128    1
154 # -- conv(21x21 x 32 stride=4) -> 28x28    32
155 # -- max(2x2)                  -> 14x14      6
156 # -- conv(7x7 x 96)            -> 8x8      16
157 # -- max(2x2)                  -> 4x4      16
158 # -- conv(5x5 x 96)            -> 26x36      16
159 # -- conv(3x3 x 128)           -> 36x36      16
160 # -- conv(3x3 x 128)           -> 36x36      16
161
162 # -- conv(5x5 x 120) -> 1x1        120
163 # -- reshape           -> 120        1
164 # -- full(3x84)      -> 84         1
165 # -- full(84x2)        -> 2          1
166
167 class AfrozeDeepNet(nn.Module):
168     def __init__(self):
169         super(AfrozeDeepNet, self).__init__()
170         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
171         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
172         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
173         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
174         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
175         self.fc1 = nn.Linear(1536, 256)
176         self.fc2 = nn.Linear(256, 256)
177         self.fc3 = nn.Linear(256, 2)
178         self.name = 'deepnet'
179
180     def forward(self, x):
181         x = self.conv1(x)
182         x = fn.max_pool2d(x, kernel_size=2)
183         x = fn.relu(x)
184
185         x = self.conv2(x)
186         x = fn.max_pool2d(x, kernel_size=2)
187         x = fn.relu(x)
188
189         x = self.conv3(x)
190         x = fn.relu(x)
191
192         x = self.conv4(x)
193         x = fn.relu(x)
194
195         x = self.conv5(x)
196         x = fn.max_pool2d(x, kernel_size=2)
197         x = fn.relu(x)
198
199         x = x.view(-1, 1536)
200
201         x = self.fc1(x)
202         x = fn.relu(x)
203
204         x = self.fc2(x)
205         x = fn.relu(x)
206
207         x = self.fc3(x)
208
209         return x
210
211 ######################################################################
212
213 def train_model(model, train_set):
214     batch_size = args.batch_size
215     criterion = nn.CrossEntropyLoss()
216
217     if torch.cuda.is_available():
218         criterion.cuda()
219
220     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
221
222     start_t = time.time()
223
224     for e in range(0, args.nb_epochs):
225         acc_loss = 0.0
226         for b in range(0, train_set.nb_batches):
227             input, target = train_set.get_batch(b)
228             output = model.forward(Variable(input))
229             loss = criterion(output, Variable(target))
230             acc_loss = acc_loss + loss.data[0]
231             model.zero_grad()
232             loss.backward()
233             optimizer.step()
234         dt = (time.time() - start_t) / (e + 1)
235         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
236                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
237
238     return model
239
240 ######################################################################
241
242 def nb_errors(model, data_set):
243     ne = 0
244     for b in range(0, data_set.nb_batches):
245         input, target = data_set.get_batch(b)
246         output = model.forward(Variable(input))
247         wta_prediction = output.data.max(1)[1].view(-1)
248
249         for i in range(0, data_set.batch_size):
250             if wta_prediction[i] != target[i]:
251                 ne = ne + 1
252
253     return ne
254
255 ######################################################################
256
257 for arg in vars(args):
258     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
259
260 ######################################################################
261
262 def int_to_suffix(n):
263     if n > 1000000 and n%1000000 == 0:
264         return str(n//1000000) + 'M'
265     elif n > 1000 and n%1000 == 0:
266         return str(n//1000) + 'K'
267     else:
268         return str(n)
269
270 ######################################################################
271
272 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
273     print('The number of samples must be a multiple of the batch size.')
274     raise
275
276 for problem_number in range(1, 24):
277
278     log_string('**** problem ' + str(problem_number) + ' ****')
279
280     if args.deep_model:
281         model = AfrozeDeepNet()
282     else:
283         model = AfrozeShallowNet()
284
285     if torch.cuda.is_available():
286         model.cuda()
287
288     model_filename = model.name + '_' + \
289                      str(problem_number) + '_' + \
290                      int_to_suffix(args.nb_train_samples) + '.param'
291
292     nb_parameters = 0
293     for p in model.parameters(): nb_parameters += p.numel()
294     log_string('nb_parameters {:d}'.format(nb_parameters))
295
296     need_to_train = False
297     try:
298         model.load_state_dict(torch.load(model_filename))
299         log_string('loaded_model ' + model_filename)
300     except:
301         need_to_train = True
302
303     if need_to_train:
304
305         log_string('training_model ' + model_filename)
306
307         t = time.time()
308
309         if args.compress_vignettes:
310             train_set = CompressedVignetteSet(problem_number,
311                                               args.nb_train_samples, args.batch_size,
312                                               cuda = torch.cuda.is_available())
313         else:
314             train_set = VignetteSet(problem_number,
315                                     args.nb_train_samples, args.batch_size,
316                                     cuda = torch.cuda.is_available())
317
318         log_string('data_generation {:0.2f} samples / s'.format(
319             train_set.nb_samples / (time.time() - t))
320         )
321
322         train_model(model, train_set)
323         torch.save(model.state_dict(), model_filename)
324         log_string('saved_model ' + model_filename)
325
326         nb_train_errors = nb_errors(model, train_set)
327
328         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
329             problem_number,
330             100 * nb_train_errors / train_set.nb_samples,
331             nb_train_errors,
332             train_set.nb_samples)
333         )
334
335     if need_to_train or args.test_loaded_models:
336
337         t = time.time()
338
339         if args.compress_vignettes:
340             test_set = CompressedVignetteSet(problem_number,
341                                              args.nb_test_samples, args.batch_size,
342                                              cuda = torch.cuda.is_available())
343         else:
344             test_set = VignetteSet(problem_number,
345                                    args.nb_test_samples, args.batch_size,
346                                    cuda = torch.cuda.is_available())
347
348         log_string('data_generation {:0.2f} samples / s'.format(
349             test_set.nb_samples / (time.time() - t))
350         )
351
352         nb_test_errors = nb_errors(model, test_set)
353
354         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
355             problem_number,
356             100 * nb_test_errors / test_set.nb_samples,
357             nb_test_errors,
358             test_set.nb_samples)
359         )
360
361 ######################################################################