The learning continues from where it stopped if a state is found.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35 import torchvision
36
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
41 from torch import nn
42 from torch.nn import functional as fn
43
44 from torchvision import datasets, transforms, utils
45
46 # SVRT
47
48 import svrtset
49
50 ######################################################################
51
52 parser = argparse.ArgumentParser(
53     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
54     formatter_class = argparse.ArgumentDefaultsHelpFormatter
55 )
56
57 parser.add_argument('--nb_train_samples',
58                     type = int, default = 100000)
59
60 parser.add_argument('--nb_test_samples',
61                     type = int, default = 10000)
62
63 parser.add_argument('--nb_validation_samples',
64                     type = int, default = 10000)
65
66 parser.add_argument('--validation_error_threshold',
67                     type = float, default = 0.0,
68                     help = 'Early training termination criterion')
69
70 parser.add_argument('--nb_epochs',
71                     type = int, default = 50)
72
73 parser.add_argument('--batch_size',
74                     type = int, default = 100)
75
76 parser.add_argument('--log_file',
77                     type = str, default = 'default.log')
78
79 parser.add_argument('--nb_exemplar_vignettes',
80                     type = int, default = 32)
81
82 parser.add_argument('--compress_vignettes',
83                     type = distutils.util.strtobool, default = 'True',
84                     help = 'Use lossless compression to reduce the memory footprint')
85
86 parser.add_argument('--model',
87                     type = str, default = 'deepnet',
88                     help = 'What model to use')
89
90 parser.add_argument('--test_loaded_models',
91                     type = distutils.util.strtobool, default = 'False',
92                     help = 'Should we compute the test errors of loaded models')
93
94 parser.add_argument('--problems',
95                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
96                     help = 'What problems to process')
97
98 args = parser.parse_args()
99
100 ######################################################################
101
102 log_file = open(args.log_file, 'a')
103 pred_log_t = None
104 last_tag_t = time.time()
105
106 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
107
108 # Log and prints the string, with a time stamp. Does not log the
109 # remark
110
111 def log_string(s, remark = ''):
112     global pred_log_t, last_tag_t
113
114     t = time.time()
115
116     if pred_log_t is None:
117         elapsed = 'start'
118     else:
119         elapsed = '+{:.02f}s'.format(t - pred_log_t)
120
121     pred_log_t = t
122
123     if t > last_tag_t + 3600:
124         last_tag_t = t
125         print(Fore.RED + time.ctime() + Style.RESET_ALL)
126
127     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
128     log_file.flush()
129
130     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
131
132 ######################################################################
133
134 # Afroze's ShallowNet
135
136 #                       map size   nb. maps
137 #                     ----------------------
138 #    input                128x128    1
139 # -- conv(21x21 x 6)   -> 108x108    6
140 # -- max(2x2)          -> 54x54      6
141 # -- conv(19x19 x 16)  -> 36x36      16
142 # -- max(2x2)          -> 18x18      16
143 # -- conv(18x18 x 120) -> 1x1        120
144 # -- reshape           -> 120        1
145 # -- full(120x84)      -> 84         1
146 # -- full(84x2)        -> 2          1
147
148 class AfrozeShallowNet(nn.Module):
149     name = 'shallownet'
150
151     def __init__(self):
152         super(AfrozeShallowNet, self).__init__()
153         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
154         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
155         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
156         self.fc1 = nn.Linear(120, 84)
157         self.fc2 = nn.Linear(84, 2)
158
159     def forward(self, x):
160         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
161         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
162         x = fn.relu(self.conv3(x))
163         x = x.view(-1, 120)
164         x = fn.relu(self.fc1(x))
165         x = self.fc2(x)
166         return x
167
168 ######################################################################
169
170 # Afroze's DeepNet
171
172 class AfrozeDeepNet(nn.Module):
173
174     name = 'deepnet'
175
176     def __init__(self):
177         super(AfrozeDeepNet, self).__init__()
178         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
179         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
180         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
181         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
182         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
183         self.fc1 = nn.Linear(1536, 256)
184         self.fc2 = nn.Linear(256, 256)
185         self.fc3 = nn.Linear(256, 2)
186
187     def forward(self, x):
188         x = self.conv1(x)
189         x = fn.max_pool2d(x, kernel_size=2)
190         x = fn.relu(x)
191
192         x = self.conv2(x)
193         x = fn.max_pool2d(x, kernel_size=2)
194         x = fn.relu(x)
195
196         x = self.conv3(x)
197         x = fn.relu(x)
198
199         x = self.conv4(x)
200         x = fn.relu(x)
201
202         x = self.conv5(x)
203         x = fn.max_pool2d(x, kernel_size=2)
204         x = fn.relu(x)
205
206         x = x.view(-1, 1536)
207
208         x = self.fc1(x)
209         x = fn.relu(x)
210
211         x = self.fc2(x)
212         x = fn.relu(x)
213
214         x = self.fc3(x)
215
216         return x
217
218 ######################################################################
219
220 class DeepNet2(nn.Module):
221     name = 'deepnet2'
222
223     def __init__(self):
224         super(DeepNet2, self).__init__()
225         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
226         self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
227         self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
228         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
229         self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
230         self.fc1 = nn.Linear(2048, 512)
231         self.fc2 = nn.Linear(512, 512)
232         self.fc3 = nn.Linear(512, 2)
233
234     def forward(self, x):
235         x = self.conv1(x)
236         x = fn.max_pool2d(x, kernel_size=2)
237         x = fn.relu(x)
238
239         x = self.conv2(x)
240         x = fn.max_pool2d(x, kernel_size=2)
241         x = fn.relu(x)
242
243         x = self.conv3(x)
244         x = fn.relu(x)
245
246         x = self.conv4(x)
247         x = fn.relu(x)
248
249         x = self.conv5(x)
250         x = fn.max_pool2d(x, kernel_size=2)
251         x = fn.relu(x)
252
253         x = x.view(-1, 2048)
254
255         x = self.fc1(x)
256         x = fn.relu(x)
257
258         x = self.fc2(x)
259         x = fn.relu(x)
260
261         x = self.fc3(x)
262
263         return x
264
265 ######################################################################
266
267 def nb_errors(model, data_set):
268     ne = 0
269     for b in range(0, data_set.nb_batches):
270         input, target = data_set.get_batch(b)
271         output = model.forward(Variable(input))
272         wta_prediction = output.data.max(1)[1].view(-1)
273
274         for i in range(0, data_set.batch_size):
275             if wta_prediction[i] != target[i]:
276                 ne = ne + 1
277
278     return ne
279
280 ######################################################################
281
282 def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
283     batch_size = args.batch_size
284     criterion = nn.CrossEntropyLoss()
285
286     if torch.cuda.is_available():
287         criterion.cuda()
288
289     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
290
291     start_t = time.time()
292
293     for e in range(nb_epochs_done, args.nb_epochs):
294         acc_loss = 0.0
295         for b in range(0, train_set.nb_batches):
296             input, target = train_set.get_batch(b)
297             output = model.forward(Variable(input))
298             loss = criterion(output, Variable(target))
299             acc_loss = acc_loss + loss.data[0]
300             model.zero_grad()
301             loss.backward()
302             optimizer.step()
303         dt = (time.time() - start_t) / (e + 1)
304
305         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
306                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
307
308         torch.save([ model.state_dict(), e + 1 ], model_filename)
309
310         if validation_set is not None:
311             nb_validation_errors = nb_errors(model, validation_set)
312
313             log_string('validation_error {:.02f}% {:d} {:d}'.format(
314                 100 * nb_validation_errors / validation_set.nb_samples,
315                 nb_validation_errors,
316                 validation_set.nb_samples)
317             )
318
319             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
320                 log_string('below validation_error_threshold')
321                 break
322
323     return model
324
325 ######################################################################
326
327 for arg in vars(args):
328     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
329
330 ######################################################################
331
332 def int_to_suffix(n):
333     if n >= 1000000 and n%1000000 == 0:
334         return str(n//1000000) + 'M'
335     elif n >= 1000 and n%1000 == 0:
336         return str(n//1000) + 'K'
337     else:
338         return str(n)
339
340 class vignette_logger():
341     def __init__(self, delay_min = 60):
342         self.start_t = time.time()
343         self.last_t = self.start_t
344         self.delay_min = delay_min
345
346     def __call__(self, n, m):
347         t = time.time()
348         if t > self.last_t + self.delay_min:
349             dt = (t - self.start_t) / m
350             log_string('sample_generation {:d} / {:d}'.format(
351                 m,
352                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
353             )
354             self.last_t = t
355
356 def save_examplar_vignettes(data_set, nb, name):
357     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
358
359     for k in range(0, nb):
360         b = n[k] // data_set.batch_size
361         m = n[k] % data_set.batch_size
362         i, t = data_set.get_batch(b)
363         i = i[m].float()
364         i.sub_(i.min())
365         i.div_(i.max())
366         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
367         patchwork[k].copy_(i)
368
369     torchvision.utils.save_image(patchwork, name)
370
371 ######################################################################
372
373 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
374     print('The number of samples must be a multiple of the batch size.')
375     raise
376
377 log_string('############### start ###############')
378
379 if args.compress_vignettes:
380     log_string('using_compressed_vignettes')
381     VignetteSet = svrtset.CompressedVignetteSet
382 else:
383     log_string('using_uncompressed_vignettes')
384     VignetteSet = svrtset.VignetteSet
385
386 ########################################
387 model_class = None
388 for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]:
389     if args.model == m.name:
390         model_class = m
391         break
392 if model_class is None:
393     print('Unknown model ' + args.model)
394     raise
395
396 log_string('using model class ' + m.name)
397 ########################################
398
399 for problem_number in map(int, args.problems.split(',')):
400
401     log_string('############### problem ' + str(problem_number) + ' ###############')
402
403     model = model_class()
404
405     if torch.cuda.is_available(): model.cuda()
406
407     model_filename = model.name + '_pb:' + \
408                      str(problem_number) + '_ns:' + \
409                      int_to_suffix(args.nb_train_samples) + '.state'
410
411     nb_parameters = 0
412     for p in model.parameters(): nb_parameters += p.numel()
413     log_string('nb_parameters {:d}'.format(nb_parameters))
414
415     ##################################################
416     # Tries to load the model
417
418     need_to_train = False
419     try:
420         model_state_dict, nb_epochs_done = torch.load(model_filename)
421         model.load_state_dict(model_state_dict)
422         log_string('loaded_model ' + model_filename)
423     except:
424         nb_epochs_done = 0
425
426
427     ##################################################
428     # Train if necessary
429
430     if nb_epochs_done < args.nb_epochs:
431
432         log_string('training_model ' + model_filename)
433
434         t = time.time()
435
436         train_set = VignetteSet(problem_number,
437                                 args.nb_train_samples, args.batch_size,
438                                 cuda = torch.cuda.is_available(),
439                                 logger = vignette_logger())
440
441         log_string('data_generation {:0.2f} samples / s'.format(
442             train_set.nb_samples / (time.time() - t))
443         )
444
445         if args.nb_exemplar_vignettes > 0:
446             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
447                                     'examplar_{:d}.png'.format(problem_number))
448
449         if args.validation_error_threshold > 0.0:
450             validation_set = VignetteSet(problem_number,
451                                          args.nb_validation_samples, args.batch_size,
452                                          cuda = torch.cuda.is_available(),
453                                          logger = vignette_logger())
454         else:
455             validation_set = None
456
457         train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
458         log_string('saved_model ' + model_filename)
459
460         nb_train_errors = nb_errors(model, train_set)
461
462         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
463             problem_number,
464             100 * nb_train_errors / train_set.nb_samples,
465             nb_train_errors,
466             train_set.nb_samples)
467         )
468
469     ##################################################
470     # Test if necessary
471
472     if need_to_train or args.test_loaded_models:
473
474         t = time.time()
475
476         test_set = VignetteSet(problem_number,
477                                args.nb_test_samples, args.batch_size,
478                                cuda = torch.cuda.is_available())
479
480         nb_test_errors = nb_errors(model, test_set)
481
482         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
483             problem_number,
484             100 * nb_test_errors / test_set.nb_samples,
485             nb_test_errors,
486             test_set.nb_samples)
487         )
488
489 ######################################################################