Deal with missing colorama.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
try:
    from colorama import Fore, Back, Style
except ImportError:
    # colorama is optional. The rest of the file reads attributes such
    # as Fore.RED or Style.RESET_ALL, so the stand-ins must answer any
    # such attribute with an empty string; plain '' objects would raise
    # AttributeError on the first colored print.
    class _NoColor(object):
        """Stand-in for a colorama palette: every color code is ''."""

        def __getattr__(self, attribute):
            # Keep private / dunder lookups failing normally so that
            # copy, pickle and friends do not get a bogus '' back.
            if attribute.startswith('_'):
                raise AttributeError(attribute)
            return ''

    Fore = Back = Style = _NoColor()
36
37 # Pytorch
38
39 import torch
40 import torchvision
41
42 from torch import optim
43 from torch import multiprocessing
44 from torch import FloatTensor as Tensor
45 from torch.autograd import Variable
46 from torch import nn
47 from torch.nn import functional as fn
48
49 from torchvision import datasets, transforms, utils
50
51 # SVRT
52
53 import svrtset
54
55 ######################################################################
56
# Command line interface. The options are declared as a table and
# registered in a loop, in the order they should appear in --help.
#
# The boolean options go through distutils.util.strtobool; argparse also
# runs string defaults through `type`, so 'True' / 'False' end up as the
# integers 1 / 0 in `args`.

parser = argparse.ArgumentParser(
    formatter_class = argparse.ArgumentDefaultsHelpFormatter,
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
)

for _flag, _options in [
    ('--nb_train_samples',
     dict(type = int, default = 100000)),

    ('--nb_test_samples',
     dict(type = int, default = 10000)),

    ('--nb_validation_samples',
     dict(type = int, default = 10000)),

    ('--validation_error_threshold',
     dict(type = float, default = 0.0,
          help = 'Early training termination criterion')),

    ('--nb_epochs',
     dict(type = int, default = 50)),

    ('--batch_size',
     dict(type = int, default = 100)),

    ('--log_file',
     dict(type = str, default = 'default.log')),

    ('--nb_exemplar_vignettes',
     dict(type = int, default = 32)),

    ('--compress_vignettes',
     dict(type = distutils.util.strtobool, default = 'True',
          help = 'Use lossless compression to reduce the memory footprint')),

    ('--save_test_mistakes',
     dict(type = distutils.util.strtobool, default = 'False')),

    ('--model',
     dict(type = str, default = 'deepnet',
          help = 'What model to use')),

    ('--test_loaded_models',
     dict(type = distutils.util.strtobool, default = 'False',
          help = 'Should we compute the test errors of loaded models')),

    ('--problems',
     dict(type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
          help = 'What problems to process')),
]:
    parser.add_argument(_flag, **_options)

# Do not leave the loop variables behind in the module namespace.
del _flag, _options

args = parser.parse_args()
107
108 ######################################################################
109
# The log file is opened in append mode and stays open for the whole
# run; each run starts with a banner carrying the current date.
log_file = open(args.log_file, 'a')
log_file.writelines([
    '\n',
    '@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n',
    '\n',
])

# Time of the previous log_string() call (None until the first one) and
# time of the last hourly date tag printed on the console.
pred_log_t, last_tag_t = None, time.time()

print(''.join([Fore.RED, 'Logging into ', args.log_file, Style.RESET_ALL]))
119
def log_string(s, remark = ''):
    """Write `s` to the log file and print it, both with a time stamp.

    The line also carries the time elapsed since the previous call
    ('start' on the first call). `remark` is only printed on the
    console, never written to the log file. The current date is printed
    on the console when more than 3600 seconds have passed since the
    last such tag.
    """
    global pred_log_t, last_tag_t

    now = time.time()

    elapsed = 'start' if pred_log_t is None else '+{:.02f}s'.format(now - pred_log_t)
    pred_log_t = now

    if now > last_tag_t + 3600:
        last_tag_t = now
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # In the file the date has no spaces, so that a log line splits
    # into space-separated fields.
    stamp = time.ctime().replace(' ', '_')
    log_file.write(stamp + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    console_line = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed
    console_line = console_line + Style.RESET_ALL + ' '
    console_line = console_line + s + Fore.CYAN + remark + Style.RESET_ALL
    print(console_line)
147
148 ######################################################################
149
def handler_sigint(signum, frame):
    """Log the reception of SIGINT and terminate with exit status 0."""
    # sys.exit rather than the bare `exit`: the latter is injected by
    # the site module for interactive sessions and is not guaranteed to
    # exist (e.g. with `python -S`).
    import sys
    log_string('got sigint')
    sys.exit(0)

def handler_sigterm(signum, frame):
    """Log the reception of SIGTERM and terminate with exit status 0."""
    import sys
    log_string('got sigterm')
    sys.exit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
160
161 ######################################################################
162
163 # Afroze's ShallowNet
164
165 #                       map size   nb. maps
166 #                     ----------------------
167 #    input                128x128    1
168 # -- conv(21x21 x 6)   -> 108x108    6
169 # -- max(2x2)          -> 54x54      6
170 # -- conv(19x19 x 16)  -> 36x36      16
171 # -- max(2x2)          -> 18x18      16
172 # -- conv(18x18 x 120) -> 1x1        120
173 # -- reshape           -> 120        1
174 # -- full(120x84)      -> 84         1
175 # -- full(84x2)        -> 2          1
176
class AfrozeShallowNet(nn.Module):
    """Three convolutions followed by two fully connected layers.

    Maps a batch of single-channel images to two scores per image. The
    final 18x18 convolution is expected to reduce each map to 1x1 so
    that the features can be flattened to 120 values, which is the case
    for 128x128 inputs.
    """

    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        # Attribute names are the keys of the saved state dictionaries.
        self.conv1 = nn.Conv2d(1, 6, kernel_size = 21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size = 19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size = 18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the (N, 2) scores for a batch `x` of images."""
        # The first two stages are identical: conv -> 2x2 max-pool -> ReLU
        for conv in (self.conv1, self.conv2):
            pooled = fn.max_pool2d(conv(x), kernel_size = 2)
            x = fn.relu(pooled)

        features = fn.relu(self.conv3(x)).view(-1, 120)
        hidden = fn.relu(self.fc1(features))
        return self.fc2(hidden)
196
197 ######################################################################
198
199 # Afroze's DeepNet
200
class AfrozeDeepNet(nn.Module):
    """Five convolutions followed by three fully connected layers.

    Maps a batch of single-channel images to two scores per image. The
    features are flattened to 1536 = 96 x 4 x 4 values, which matches
    128x128 inputs (stride-4 first convolution, then three 2x2
    max-poolings).
    """

    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        # Attribute names are the keys of the saved state dictionaries.
        self.conv1 = nn.Conv2d(1, 32, kernel_size = 7, stride = 4, padding = 3)
        self.conv2 = nn.Conv2d(32, 96, kernel_size = 5, padding = 2)
        self.conv3 = nn.Conv2d(96, 128, kernel_size = 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size = 3, padding = 1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        """Return the (N, 2) scores for a batch `x` of images."""
        # Convolutional part; the max-pooling, when there is one, is
        # applied before the ReLU.
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size = 2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size = 2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size = 2))

        # Fully connected part
        x = x.view(-1, 1536)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
246
247 ######################################################################
248
class DeepNet2(nn.Module):
    """Same layout as AfrozeDeepNet with wider convolutional layers.

    Maps a batch of single-channel images to two scores per image. The
    features are flattened to 16 * nb_channels values, i.e. a 4x4 final
    map, which matches 128x128 inputs.

    Args:
        nb_channels: number of maps produced by conv2 to conv5. The
            default, 512, is the value the class used to hard-code, so
            DeepNet2() builds the same network as before and remains
            compatible with previously saved states.
    """

    name = 'deepnet2'

    def __init__(self, nb_channels = 512):
        super(DeepNet2, self).__init__()
        self.nb_channels = nb_channels
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, self.nb_channels, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(16 * self.nb_channels, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        """Return the (N, 2) scores for a batch `x` of images."""
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # One row of 16 * nb_channels features per sample
        x = x.view(-1, 16 * self.nb_channels)

        x = self.fc1(x)
        x = fn.relu(x)

        x = self.fc2(x)
        x = fn.relu(x)

        x = self.fc3(x)

        return x
294
295 ######################################################################
296
class DeepNet3(nn.Module):
    """Seven convolutions followed by three fully connected layers.

    Maps a batch of single-channel images to two scores per image. The
    features are flattened to 2048 = 128 x 4 x 4 values, which matches
    128x128 inputs.
    """

    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        # Attribute names are the keys of the saved state dictionaries.
        self.conv1 = nn.Conv2d(1, 32, kernel_size = 7, stride = 4, padding = 3)
        self.conv2 = nn.Conv2d(32, 128, kernel_size = 5, padding = 2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size = 3, padding = 1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        """Return the (N, 2) scores for a batch `x` of images."""
        # Convolutional part; the max-pooling, when there is one, is
        # applied before the ReLU.
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size = 2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size = 2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size = 2))
        x = fn.relu(self.conv6(x))
        x = fn.relu(self.conv7(x))

        # Fully connected part
        x = x.view(-1, 2048)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
349
350 ######################################################################
351
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Count the samples of `data_set` that `model` misclassifies.

    The predicted class of a sample is the index of its largest output.

    Args:
        model: object whose forward() maps a batch of inputs to one row
            of scores per sample.
        data_set: object providing nb_batches, batch_size and
            get_batch(b) -> (input, target).
        mistake_filename_pattern: if not None, every misclassified
            input is rescaled and saved as an image; the pattern is
            formatted with the index of the sample in the set and its
            target class.

    Returns:
        The number of misclassified samples.
    """
    error_count = 0
    save_mistakes = mistake_filename_pattern is not None

    for batch_index in range(data_set.nb_batches):
        batch_input, batch_target = data_set.get_batch(batch_index)
        scores = model.forward(Variable(batch_input))
        _, winners = scores.data.max(1)
        predicted = winners.view(-1)

        for i in range(data_set.batch_size):
            is_mistake = predicted[i] != batch_target[i]
            if not is_mistake:
                continue

            error_count += 1

            if not save_mistakes:
                continue

            # Shift the minimum to 0, then divide by the new maximum
            picture = batch_input[i].clone()
            picture.sub_(picture.min())
            picture.div_(picture.max())

            sample_index = batch_index * data_set.batch_size + i
            filename = mistake_filename_pattern.format(sample_index, batch_target[i])
            torchvision.utils.save_image(picture, filename)
            print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)

    return error_count
371
372 ######################################################################
373
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train `model` by SGD up to args.nb_epochs epochs.

    After every epoch the accumulated training loss is logged and the
    pair [ state_dict, number of epochs done ] is saved to
    `model_filename`, so that an interrupted run can be resumed.

    Args:
        model: the network to train, modified in place.
        model_filename: where the state is saved after each epoch.
        train_set: object providing nb_batches and
            get_batch(b) -> (input, target).
        validation_set: same interface plus nb_samples, or None. When
            given, the validation error is logged after each epoch and
            training stops as soon as the error rate is not greater
            than args.validation_error_threshold.
        nb_epochs_done: number of epochs already done by a previous
            run; training starts at that epoch.

    Returns:
        The trained model (the same object as `model`).
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # PyTorch >= 0.4 returns a 0-dim loss which must be read
            # with item(); older versions need the data[0] indexing.
            if hasattr(loss, 'item'):
                acc_loss = acc_loss + loss.item()
            else:
                acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # The elapsed time only covers the epochs of this run, which is
        # not e + 1 when resuming, and e + 1 epochs are now complete.
        nb_epochs_this_run = e + 1 - nb_epochs_done
        nb_epochs_left = args.nb_epochs - (e + 1)
        dt = (time.time() - start_t) / nb_epochs_this_run

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * nb_epochs_left) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            # float() forces a true division: with an integer division
            # the rate would truncate to 0 and stop training at once.
            validation_error = float(nb_validation_errors) / validation_set.nb_samples

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * validation_error,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if validation_error <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
416
417 ######################################################################
418
# Keep a trace of every command line argument in the log
for arg in vars(args):
    log_string('argument {} {}'.format(arg, getattr(args, arg)))
421
422 ######################################################################
423
def int_to_suffix(n):
    """Return `n` as a string, shortened with 'M' or 'K' when exact.

    Exact multiples of 1000000 (at least 1000000) become '<k>M', other
    exact multiples of 1000 (at least 1000) become '<k>K', and any
    other value is returned as str(n).
    """
    for unit, suffix in ((1000000, 'M'), (1000, 'K')):
        if n >= unit and n % unit == 0:
            return str(n // unit) + suffix
    return str(n)
431
class vignette_logger():
    """Callable that logs the progress of a sample generation.

    At most one line is logged every `delay_min` seconds (the delay is
    compared with differences of time.time() values).
    """

    def __init__(self, delay_min = 60):
        now = time.time()
        self.start_t = now
        self.last_t = now
        self.delay_min = delay_min

    def __call__(self, n, m):
        """Log 'sample_generation m / n' with an ETA, if it is time to.

        The ETA assumes the n - m remaining samples take as long, on
        average, as the m samples produced since the logger was
        created.
        """
        t = time.time()

        # Too soon since the previous line: stay silent
        if t <= self.last_t + self.delay_min:
            return

        dt = (t - self.start_t) / m
        eta = time.ctime(time.time() + dt * (n - m))
        log_string('sample_generation {:d} / {:d}'.format(m, n),
                   ' [ETA ' + eta + ']')
        self.last_t = t
447
def save_examplar_vignettes(data_set, nb, name):
    """Save `nb` randomly picked samples of `data_set` in one image.

    Args:
        data_set: object providing nb_samples, batch_size and
            get_batch(b) -> (input, target).
        nb: number of samples to pick, without repetition.
        name: file name of the image to write.
    """
    # The first nb entries of a random permutation of the sample indexes
    picked = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for slot in range(nb):
        sample_index = picked[slot]
        batch_index = sample_index // data_set.batch_size
        offset = sample_index % data_set.batch_size

        batch_input, _ = data_set.get_batch(batch_index)
        vignette = batch_input[offset].float()

        # Shift the minimum to 0, then divide by the new maximum
        vignette.sub_(vignette.min())
        vignette.div_(vignette.max())

        # The size of the vignettes is only known once the first one
        # has been fetched; a single channel is assumed.
        if slot == 0:
            patchwork = Tensor(nb, 1, vignette.size(1), vignette.size(2))

        patchwork[slot].copy_(vignette)

    torchvision.utils.save_image(patchwork, name)
462
463 ######################################################################
464
# The train and test sets must split into whole batches. The message is
# carried by the exception itself: a bare `raise` with no active
# exception would only produce an unrelated RuntimeError.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    raise ValueError('The number of samples must be a multiple of the batch size.')
468
# Pick the class used for all the data sets, and log the choice
if not args.compress_vignettes:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
else:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
475
476 ########################################
# Select the model class whose `name` is the one given with --model
model_class = next(
    (candidate
     for candidate in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]
     if candidate.name == args.model),
    None
)

if model_class is None:
    # An explicit exception: a bare `raise` with no active exception
    # would only produce an unrelated RuntimeError.
    raise ValueError('Unknown model ' + args.model)

log_string('using model class ' + model_class.name)
487 ########################################
488
# Main loop: for every requested problem, build the model, resume from
# its saved state if there is one, train it up to args.nb_epochs epochs
# and compute its test error.

for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except (IOError, OSError):
        # No readable state file, which is the normal case for a first
        # run: train from scratch.
        nb_epochs_done = 0
    except Exception as e:
        # The file is there but cannot be used. Training from scratch
        # is still the best effort, but the reason should not be lost.
        # SystemExit and KeyboardInterrupt are deliberately not caught,
        # so that the signal handlers can stop the program here.
        log_string('cannot_load_model ' + model_filename + ' ' + str(e))
        nb_epochs_done = 0

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only needed for the early termination
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename,
                    train_set, validation_set,
                    nb_epochs_done = nb_epochs_done)

        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        # 100.0 forces a true division regardless of the Python version
        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # The misclassified test vignettes are only written when asked
        # for with --save_test_mistakes.
        if args.save_test_mistakes:
            mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_filename_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_filename_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
581
582 ######################################################################