# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import time, argparse, distutils.util, re, signal

try:
    from colorama import Fore, Back, Style
except ImportError:
    # colorama is optional; without it the output is simply not coloured
    Fore, Back, Style = '', '', ''

import torch, torchvision

from torch import optim
from torch import multiprocessing
from torch import nn
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

import svrtset
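
# svrtset (part of this repository) is expected to provide the VignetteSet
# and CompressedVignetteSet classes used below to generate the SVRT
# vignettes, either raw or losslessly compressed in memory.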

######################################################################

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
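
# Example invocation (assuming this script is saved as cnn_svrt.py, a
# hypothetical file name):
#
#   python cnn_svrt.py --model=deepnet --problems=1,5,21 --nb_train_samples=10000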

######################################################################

log_file = open(args.log_file, 'a')

log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')

pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string, with a time stamp. Does not log the
# remark.

def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed
          + Style.RESET_ALL + ' ' + s
          + Fore.CYAN + remark
          + Style.RESET_ALL)
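
# A line in the log file thus looks roughly like:
#   Sat_Jan__7_12:34:56_2017 +0.42s train_loss 3 1234.567890
# The console gets the same information in colour, plus the remark
# (typically an ETA), which is not written to the file.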

######################################################################

def handler_sigint(signum, frame):
    log_string('got sigint')
    # terminate cleanly so that the log records the interruption
    exit(0)

def handler_sigterm(signum, frame):
    log_string('got sigterm')
    exit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)

######################################################################

# Afroze's ShallowNet

#                          map size   nb. maps
#                        ----------------------
# input                     128x128      1
# -- conv(21x21 x 6)     -> 108x108      6
# -- max(2x2)            -> 54x54        6
# -- conv(19x19 x 16)    -> 36x36       16
# -- max(2x2)            -> 18x18       16
# -- conv(18x18 x 120)   -> 1x1        120
# -- reshape             -> 120          1
# -- full(120x84)        -> 84           1
# -- full(84x2)          -> 2            1

class AfrozeShallowNet(nn.Module):

    # Identifier matched against the --model command-line option
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x

######################################################################

class AfrozeDeepNet(nn.Module):

    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        # three 2x2 max-poolings on the 32x32 conv1 output leave 4x4 maps,
        # hence 4 * 4 * 96 = 1536 features
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 1536)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)
        return x

######################################################################

class DeepNet2(nn.Module):

    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.nb_channels = 512
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d(32, self.nb_channels, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        # the maps are 4x4 after the three max-poolings, hence 16 * nb_channels features
        self.fc1 = nn.Linear(16 * self.nb_channels, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 16 * self.nb_channels)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)
        return x

######################################################################

class DeepNet3(nn.Module):

    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # the maps are 4x4 x 128 channels after the three max-poolings, hence 2048 features
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = fn.relu(self.conv6(x))
        x = fn.relu(self.conv7(x))
        x = x.view(-1, 2048)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)
        return x

######################################################################

def nb_errors(model, data_set, mistake_filename_pattern = None):
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
                if mistake_filename_pattern is not None:
                    img = input[i].clone()
                    # rescale to [0, 1] so that the saved vignette is visible
                    img.sub_(img.min())
                    img.div_(max(img.max(), 1e-12))
                    k = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(k, target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)

    return ne

######################################################################

def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0

        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model

######################################################################

for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

######################################################################

def int_to_suffix(n):
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
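
# For instance, int_to_suffix(100000) returns '100K',
# int_to_suffix(2000000) returns '2M', and int_to_suffix(1234)
# falls through to plain '1234'.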

class vignette_logger():
    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
            self.last_t = t

def save_examplar_vignettes(data_set, nb, name):
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)
        # pick the m-th vignette of the batch and rescale it to [0, 1]
        i = i[m].float()
        i.sub_(i.min())
        i.div_(i.max())
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)

######################################################################

if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    exit(1)

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet

########################################

model_class = None
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m
        break

if model_class is None:
    print('Unknown model ' + args.model)
    exit(1)

log_string('using model class ' + m.name)

########################################

for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except FileNotFoundError:
        # no saved model yet: start training from scratch
        nb_epochs_done = 0

    ##################################################

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename,
                    train_set, validation_set,
                    nb_epochs_done = nb_epochs_done)

        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # only write the misclassified vignettes to disk if requested
        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
                                   if args.save_test_mistakes else None)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################