ade87ceea78dba435a3f9982e2e55f1bb719357e
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
32 from colorama import Fore, Back, Style
33
34 # Pytorch
35
36 import torch
37 import torchvision
38
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
43 from torch import nn
44 from torch.nn import functional as fn
45
46 from torchvision import datasets, transforms, utils
47
48 # SVRT
49
50 import svrtset
51
52 ######################################################################
53
54 parser = argparse.ArgumentParser(
55     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
56     formatter_class = argparse.ArgumentDefaultsHelpFormatter
57 )
58
59 parser.add_argument('--nb_train_samples',
60                     type = int, default = 100000)
61
62 parser.add_argument('--nb_test_samples',
63                     type = int, default = 10000)
64
65 parser.add_argument('--nb_validation_samples',
66                     type = int, default = 10000)
67
68 parser.add_argument('--validation_error_threshold',
69                     type = float, default = 0.0,
70                     help = 'Early training termination criterion')
71
72 parser.add_argument('--nb_epochs',
73                     type = int, default = 50)
74
75 parser.add_argument('--batch_size',
76                     type = int, default = 100)
77
78 parser.add_argument('--log_file',
79                     type = str, default = 'default.log')
80
81 parser.add_argument('--nb_exemplar_vignettes',
82                     type = int, default = 32)
83
84 parser.add_argument('--compress_vignettes',
85                     type = distutils.util.strtobool, default = 'True',
86                     help = 'Use lossless compression to reduce the memory footprint')
87
88 parser.add_argument('--model',
89                     type = str, default = 'deepnet',
90                     help = 'What model to use')
91
92 parser.add_argument('--test_loaded_models',
93                     type = distutils.util.strtobool, default = 'False',
94                     help = 'Should we compute the test errors of loaded models')
95
96 parser.add_argument('--problems',
97                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
98                     help = 'What problems to process')
99
100 args = parser.parse_args()
101
102 ######################################################################
103
104 log_file = open(args.log_file, 'a')
105 pred_log_t = None
106 last_tag_t = time.time()
107
108 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
109
110 # Log and prints the string, with a time stamp. Does not log the
111 # remark
112
113 def log_string(s, remark = ''):
114     global pred_log_t, last_tag_t
115
116     t = time.time()
117
118     if pred_log_t is None:
119         elapsed = 'start'
120     else:
121         elapsed = '+{:.02f}s'.format(t - pred_log_t)
122
123     pred_log_t = t
124
125     if t > last_tag_t + 3600:
126         last_tag_t = t
127         print(Fore.RED + time.ctime() + Style.RESET_ALL)
128
129     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
130     log_file.flush()
131
132     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \
133           + Style.RESET_ALL
134           + ' ' \
135           + s + Fore.CYAN + remark \
136           + Style.RESET_ALL)
137
138 ######################################################################
139
140 def handler_sigint(signum, frame):
141     log_string('got sigint')
142     exit(0)
143
144 def handler_sigterm(signum, frame):
145     log_string('got sigterm')
146     exit(0)
147
148 signal.signal(signal.SIGINT, handler_sigint)
149 signal.signal(signal.SIGTERM, handler_sigterm)
150
151 ######################################################################
152
153 # Afroze's ShallowNet
154
155 #                       map size   nb. maps
156 #                     ----------------------
157 #    input                128x128    1
158 # -- conv(21x21 x 6)   -> 108x108    6
159 # -- max(2x2)          -> 54x54      6
160 # -- conv(19x19 x 16)  -> 36x36      16
161 # -- max(2x2)          -> 18x18      16
162 # -- conv(18x18 x 120) -> 1x1        120
163 # -- reshape           -> 120        1
164 # -- full(120x84)      -> 84         1
165 # -- full(84x2)        -> 2          1
166
167 class AfrozeShallowNet(nn.Module):
168     name = 'shallownet'
169
170     def __init__(self):
171         super(AfrozeShallowNet, self).__init__()
172         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
173         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
174         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
175         self.fc1 = nn.Linear(120, 84)
176         self.fc2 = nn.Linear(84, 2)
177
178     def forward(self, x):
179         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
180         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
181         x = fn.relu(self.conv3(x))
182         x = x.view(-1, 120)
183         x = fn.relu(self.fc1(x))
184         x = self.fc2(x)
185         return x
186
187 ######################################################################
188
189 # Afroze's DeepNet
190
191 class AfrozeDeepNet(nn.Module):
192
193     name = 'deepnet'
194
195     def __init__(self):
196         super(AfrozeDeepNet, self).__init__()
197         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
198         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
199         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
200         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
201         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
202         self.fc1 = nn.Linear(1536, 256)
203         self.fc2 = nn.Linear(256, 256)
204         self.fc3 = nn.Linear(256, 2)
205
206     def forward(self, x):
207         x = self.conv1(x)
208         x = fn.max_pool2d(x, kernel_size=2)
209         x = fn.relu(x)
210
211         x = self.conv2(x)
212         x = fn.max_pool2d(x, kernel_size=2)
213         x = fn.relu(x)
214
215         x = self.conv3(x)
216         x = fn.relu(x)
217
218         x = self.conv4(x)
219         x = fn.relu(x)
220
221         x = self.conv5(x)
222         x = fn.max_pool2d(x, kernel_size=2)
223         x = fn.relu(x)
224
225         x = x.view(-1, 1536)
226
227         x = self.fc1(x)
228         x = fn.relu(x)
229
230         x = self.fc2(x)
231         x = fn.relu(x)
232
233         x = self.fc3(x)
234
235         return x
236
237 ######################################################################
238
239 class DeepNet2(nn.Module):
240     name = 'deepnet2'
241
242     def __init__(self):
243         super(DeepNet2, self).__init__()
244         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
245         self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
246         self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
247         self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
248         self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
249         self.fc1 = nn.Linear(4096, 512)
250         self.fc2 = nn.Linear(512, 512)
251         self.fc3 = nn.Linear(512, 2)
252
253     def forward(self, x):
254         x = self.conv1(x)
255         x = fn.max_pool2d(x, kernel_size=2)
256         x = fn.relu(x)
257
258         x = self.conv2(x)
259         x = fn.max_pool2d(x, kernel_size=2)
260         x = fn.relu(x)
261
262         x = self.conv3(x)
263         x = fn.relu(x)
264
265         x = self.conv4(x)
266         x = fn.relu(x)
267
268         x = self.conv5(x)
269         x = fn.max_pool2d(x, kernel_size=2)
270         x = fn.relu(x)
271
272         x = x.view(-1, 4096)
273
274         x = self.fc1(x)
275         x = fn.relu(x)
276
277         x = self.fc2(x)
278         x = fn.relu(x)
279
280         x = self.fc3(x)
281
282         return x
283
284 ######################################################################
285
286 class DeepNet3(nn.Module):
287     name = 'deepnet3'
288
289     def __init__(self):
290         super(DeepNet3, self).__init__()
291         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
292         self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
293         self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
294         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
295         self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
296         self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
297         self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
298         self.fc1 = nn.Linear(2048, 256)
299         self.fc2 = nn.Linear(256, 256)
300         self.fc3 = nn.Linear(256, 2)
301
302     def forward(self, x):
303         x = self.conv1(x)
304         x = fn.max_pool2d(x, kernel_size=2)
305         x = fn.relu(x)
306
307         x = self.conv2(x)
308         x = fn.max_pool2d(x, kernel_size=2)
309         x = fn.relu(x)
310
311         x = self.conv3(x)
312         x = fn.relu(x)
313
314         x = self.conv4(x)
315         x = fn.relu(x)
316
317         x = self.conv5(x)
318         x = fn.max_pool2d(x, kernel_size=2)
319         x = fn.relu(x)
320
321         x = self.conv6(x)
322         x = fn.relu(x)
323
324         x = self.conv7(x)
325         x = fn.relu(x)
326
327         x = x.view(-1, 2048)
328
329         x = self.fc1(x)
330         x = fn.relu(x)
331
332         x = self.fc2(x)
333         x = fn.relu(x)
334
335         x = self.fc3(x)
336
337         return x
338
339 ######################################################################
340
341 def nb_errors(model, data_set):
342     ne = 0
343     for b in range(0, data_set.nb_batches):
344         input, target = data_set.get_batch(b)
345         output = model.forward(Variable(input))
346         wta_prediction = output.data.max(1)[1].view(-1)
347
348         for i in range(0, data_set.batch_size):
349             if wta_prediction[i] != target[i]:
350                 ne = ne + 1
351
352     return ne
353
354 ######################################################################
355
356 def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
357     batch_size = args.batch_size
358     criterion = nn.CrossEntropyLoss()
359
360     if torch.cuda.is_available():
361         criterion.cuda()
362
363     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
364
365     start_t = time.time()
366
367     for e in range(nb_epochs_done, args.nb_epochs):
368         acc_loss = 0.0
369         for b in range(0, train_set.nb_batches):
370             input, target = train_set.get_batch(b)
371             output = model.forward(Variable(input))
372             loss = criterion(output, Variable(target))
373             acc_loss = acc_loss + loss.data[0]
374             model.zero_grad()
375             loss.backward()
376             optimizer.step()
377         dt = (time.time() - start_t) / (e + 1)
378
379         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
380                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
381
382         torch.save([ model.state_dict(), e + 1 ], model_filename)
383
384         if validation_set is not None:
385             nb_validation_errors = nb_errors(model, validation_set)
386
387             log_string('validation_error {:.02f}% {:d} {:d}'.format(
388                 100 * nb_validation_errors / validation_set.nb_samples,
389                 nb_validation_errors,
390                 validation_set.nb_samples)
391             )
392
393             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
394                 log_string('below validation_error_threshold')
395                 break
396
397     return model
398
399 ######################################################################
400
401 for arg in vars(args):
402     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
403
404 ######################################################################
405
406 def int_to_suffix(n):
407     if n >= 1000000 and n%1000000 == 0:
408         return str(n//1000000) + 'M'
409     elif n >= 1000 and n%1000 == 0:
410         return str(n//1000) + 'K'
411     else:
412         return str(n)
413
414 class vignette_logger():
415     def __init__(self, delay_min = 60):
416         self.start_t = time.time()
417         self.last_t = self.start_t
418         self.delay_min = delay_min
419
420     def __call__(self, n, m):
421         t = time.time()
422         if t > self.last_t + self.delay_min:
423             dt = (t - self.start_t) / m
424             log_string('sample_generation {:d} / {:d}'.format(
425                 m,
426                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
427             )
428             self.last_t = t
429
430 def save_examplar_vignettes(data_set, nb, name):
431     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
432
433     for k in range(0, nb):
434         b = n[k] // data_set.batch_size
435         m = n[k] % data_set.batch_size
436         i, t = data_set.get_batch(b)
437         i = i[m].float()
438         i.sub_(i.min())
439         i.div_(i.max())
440         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
441         patchwork[k].copy_(i)
442
443     torchvision.utils.save_image(patchwork, name)
444
445 ######################################################################
446
447 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
448     print('The number of samples must be a multiple of the batch size.')
449     raise
450
451 log_string('############### start ###############')
452
453 if args.compress_vignettes:
454     log_string('using_compressed_vignettes')
455     VignetteSet = svrtset.CompressedVignetteSet
456 else:
457     log_string('using_uncompressed_vignettes')
458     VignetteSet = svrtset.VignetteSet
459
460 ########################################
461 model_class = None
462 for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
463     if args.model == m.name:
464         model_class = m
465         break
466 if model_class is None:
467     print('Unknown model ' + args.model)
468     raise
469
470 log_string('using model class ' + m.name)
471 ########################################
472
473 for problem_number in map(int, args.problems.split(',')):
474
475     log_string('############### problem ' + str(problem_number) + ' ###############')
476
477     model = model_class()
478
479     if torch.cuda.is_available(): model.cuda()
480
481     model_filename = model.name + '_pb:' + \
482                      str(problem_number) + '_ns:' + \
483                      int_to_suffix(args.nb_train_samples) + '.state'
484
485     nb_parameters = 0
486     for p in model.parameters(): nb_parameters += p.numel()
487     log_string('nb_parameters {:d}'.format(nb_parameters))
488
489     ##################################################
490     # Tries to load the model
491
492     try:
493         model_state_dict, nb_epochs_done = torch.load(model_filename)
494         model.load_state_dict(model_state_dict)
495         log_string('loaded_model ' + model_filename)
496     except:
497         nb_epochs_done = 0
498
499
500     ##################################################
501     # Train if necessary
502
503     if nb_epochs_done < args.nb_epochs:
504
505         log_string('training_model ' + model_filename)
506
507         t = time.time()
508
509         train_set = VignetteSet(problem_number,
510                                 args.nb_train_samples, args.batch_size,
511                                 cuda = torch.cuda.is_available(),
512                                 logger = vignette_logger())
513
514         log_string('data_generation {:0.2f} samples / s'.format(
515             train_set.nb_samples / (time.time() - t))
516         )
517
518         if args.nb_exemplar_vignettes > 0:
519             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
520                                     'examplar_{:d}.png'.format(problem_number))
521
522         if args.validation_error_threshold > 0.0:
523             validation_set = VignetteSet(problem_number,
524                                          args.nb_validation_samples, args.batch_size,
525                                          cuda = torch.cuda.is_available(),
526                                          logger = vignette_logger())
527         else:
528             validation_set = None
529
530         train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
531         log_string('saved_model ' + model_filename)
532
533         nb_train_errors = nb_errors(model, train_set)
534
535         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
536             problem_number,
537             100 * nb_train_errors / train_set.nb_samples,
538             nb_train_errors,
539             train_set.nb_samples)
540         )
541
542     ##################################################
543     # Test if necessary
544
545     if nb_epochs_done < args.nb_epochs or args.test_loaded_models:
546
547         t = time.time()
548
549         test_set = VignetteSet(problem_number,
550                                args.nb_test_samples, args.batch_size,
551                                cuda = torch.cuda.is_available())
552
553         nb_test_errors = nb_errors(model, test_set)
554
555         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
556             problem_number,
557             100 * nb_test_errors / test_set.nb_samples,
558             nb_test_errors,
559             test_set.nb_samples)
560         )
561
562 ######################################################################