#  svrt is the ``Synthetic Visual Reasoning Test'', an image
#  generator for evaluating classification performance of machine
#  learning systems, humans and primates.
#
#  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
#  Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
#  This file is part of svrt.
#
#  svrt is free software: you can redistribute it and/or modify it
#  under the terms of the GNU General Public License version 3 as
#  published by the Free Software Foundation.
#
#  svrt is distributed in the hope that it will be useful, but
#  WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#  General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
32 from colorama import Fore, Back, Style
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
44 from torch.nn import functional as fn
46 from torchvision import datasets, transforms, utils
52 ######################################################################
# Command line interface. Boolean options go through
# distutils.util.strtobool, so they take an explicit value
# (e.g. `--compress_vignettes False`) rather than acting as flags.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of problem numbers.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
105 ######################################################################
# The log file is opened in append mode and stays open for the whole
# run; every record is flushed as soon as it is written.
log_file = open(args.log_file, 'a')

# Time of the previous call to log_string (None until the first call)
# and time at which the last hourly date banner was printed.
pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

def log_string(s, remark = ''):
    """Log and print the string s, with a time stamp.

    The record written to the log file is made of the current date
    (spaces replaced by underscores), the time elapsed since the
    previous call, and s. The same information is printed in color on
    the standard output, followed by remark, which is not written to
    the log file.
    """
    global pred_log_t, last_tag_t

    t = time.time()
    # A single date string is used for the file and the console so
    # that both carry exactly the same time stamp.
    date = time.ctime(t)

    if pred_log_t is None:
        # First call: there is no previous record to measure from.
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Print a date banner at most once per hour.
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + date + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', date) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + date + ' ' + Fore.GREEN + elapsed
          + Style.RESET_ALL
          + ' '
          + s + Fore.CYAN + remark
          + Style.RESET_ALL)
141 ######################################################################
def handler_sigint(signum, frame):
    """Log the reception of SIGINT, then terminate the process.

    signum and frame are the arguments passed by the signal module to
    every handler; they are not used.
    """
    log_string('got sigint')
    # A handler that merely returns swallows the signal: the run could
    # then no longer be interrupted from the keyboard.
    # NOTE(review): exit status 0 is a reconstruction; confirm.
    raise SystemExit(0)
def handler_sigterm(signum, frame):
    """Log the reception of SIGTERM, then terminate the process.

    signum and frame are the arguments passed by the signal module to
    every handler; they are not used.
    """
    log_string('got sigterm')
    # A handler that merely returns swallows the signal: the process
    # would then ignore termination requests.
    # NOTE(review): exit status 0 is a reconstruction; confirm.
    raise SystemExit(0)
# Install the logging handlers for interruption and termination
# requests.
for _signum, _handler in ((signal.SIGINT, handler_sigint),
                          (signal.SIGTERM, handler_sigterm)):
    signal.signal(_signum, _handler)
154 ######################################################################
# Afroze's ShallowNet
#
#                         map size   nb. maps
#                       ----------------------
#    input                 128x128    1
# -- conv(21x21 x 6)    -> 108x108    6
# -- max(2x2)           -> 54x54      6
# -- conv(19x19 x 16)   -> 36x36      16
# -- max(2x2)           -> 18x18      16
# -- conv(18x18 x 120)  -> 1x1        120
# -- reshape            -> 120        1
# -- full(120x84)       -> 84         1
# -- full(84x2)         -> 2          1
class AfrozeShallowNet(nn.Module):
    """Three convolutional layers followed by two fully connected ones.

    The network takes single-channel images and returns two scores per
    sample. The last convolution has an 18x18 kernel and fc1 expects
    120 features, which corresponds to a 128x128 input.
    """

    # Identifier compared with the --model command line argument.
    # NOTE(review): value restored during review; confirm it matches
    # the names used in existing saved model files.
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the (batch, 2) class scores for the batch of images x."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten per sample: with a wrongly sized input fc1 then fails
        # loudly, instead of the features being regrouped into a batch
        # of a different size.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        return self.fc2(x)
190 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Five convolutional layers followed by three fully connected ones.

    The network takes single-channel images and returns two scores per
    sample. fc1 expects 1536 = 96 x 4 x 4 features, which corresponds
    to a 128x128 input (stride 4 in conv1, then three 2x2 poolings).
    """

    # Identifier compared with the --model command line argument.
    # NOTE(review): value restored during review; confirm it matches
    # the names used in existing saved model files.
    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        """Return the (batch, 2) class scores for the batch of images x."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        # Flatten per sample: a hard-coded view(-1, 1536) would silently
        # merge several wrongly sized samples into one row.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
240 ######################################################################
class DeepNet2(nn.Module):
    """Wider variant of the five-convolution network (256 feature maps).

    The network takes single-channel images and returns two scores per
    sample. fc1 expects 4096 = 256 x 4 x 4 features, which corresponds
    to a 128x128 input (stride 4 in conv1, then three 2x2 poolings).
    """

    # Identifier compared with the --model command line argument.
    # NOTE(review): value restored during review; confirm it matches
    # the names used in existing saved model files.
    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        """Return the (batch, 2) class scores for the batch of images x."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        # Flatten per sample: a hard-coded view(-1, 4096) would silently
        # merge several wrongly sized samples into one row.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
287 ######################################################################
class DeepNet3(nn.Module):
    """Seven convolutional layers followed by three fully connected ones.

    The network takes single-channel images and returns two scores per
    sample. fc1 expects 2048 = 128 x 4 x 4 features, which corresponds
    to a 128x128 input (stride 4 in conv1, then three 2x2 poolings).
    """

    # Identifier compared with the --model command line argument.
    # NOTE(review): value restored during review; confirm it matches
    # the names used in existing saved model files.
    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        """Return the (batch, 2) class scores for the batch of images x."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = fn.relu(self.conv6(x))
        x = fn.relu(self.conv7(x))
        # Flatten per sample: a hard-coded view(-1, 2048) would silently
        # merge several wrongly sized samples into one row.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
342 ######################################################################
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Return the number of samples of data_set misclassified by model.

    The predicted class of a sample is the index of its largest output
    score. data_set must provide nb_batches, batch_size and
    get_batch(b), the latter returning an (input, target) pair.

    If mistake_filename_pattern is not None, every misclassified image
    is rescaled and saved under
    mistake_filename_pattern.format(sample_index, target_class).
    """
    ne = 0

    for b in range(data_set.nb_batches):
        batch_input, target = data_set.get_batch(b)
        output = model(Variable(batch_input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(data_set.batch_size):
            # Convert to plain ints so that the comparison and the
            # file name formatting do not depend on how the tensor
            # library represents a single element.
            label = int(target[i])
            if int(wta_prediction[i]) == label:
                continue

            ne += 1

            if mistake_filename_pattern is None:
                continue

            # Work on a copy so that the data set is left untouched.
            img = batch_input[i].clone()
            img.sub_(img.min())
            peak = float(img.max())
            if peak > 0:
                # A constant image has a zero range; dividing would
                # fill it with NaNs.
                img.div_(peak)

            k = b * data_set.batch_size + i
            filename = mistake_filename_pattern.format(k, label)
            torchvision.utils.save_image(img, filename)
            print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)

    return ne
364 ######################################################################
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train model with SGD and save its state after every epoch.

    Training runs from epoch nb_epochs_done up to args.nb_epochs. After
    each epoch the pair [state dict, number of epochs done] is saved to
    model_filename. If validation_set is not None, the validation
    error is logged after every epoch and training stops as soon as
    the error rate is at or below args.validation_error_threshold.

    Returns model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0

        for b in range(train_set.nb_batches):
            batch_input, target = train_set.get_batch(b)
            output = model(Variable(batch_input))
            loss = criterion(output, Variable(target))
            # view(-1)[0] reads the value whether the loss is stored as
            # a one-element or as a zero-dimensional tensor; indexing
            # the latter directly with [0] is an error.
            acc_loss += float(loss.data.view(-1)[0])
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Only the epochs run by this call contribute to the elapsed
        # time, so the average must ignore the epochs done earlier.
        nb_epochs_run = e + 1 - nb_epochs_done
        nb_epochs_left = args.nb_epochs - (e + 1)
        dt = (time.time() - start_t) / nb_epochs_run

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * nb_epochs_left) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
409 ######################################################################
# Record the value of every command line argument in the log.
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
414 ######################################################################
def int_to_suffix(n):
    """Return n as a short string, using a 'K' or 'M' suffix when exact.

    Exact multiples of one million are written with an 'M' suffix,
    other exact multiples of one thousand with a 'K' suffix, and any
    other value as plain digits:

    >>> int_to_suffix(2000000)
    '2M'
    >>> int_to_suffix(100000)
    '100K'
    >>> int_to_suffix(1500)
    '1500'
    """
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    if n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    # Always return a string: callers concatenate the result into
    # file names.
    return str(n)
class vignette_logger():
    """Callable logging the progress of the sample generation.

    An instance is called as logger(n, m), with m the number of samples
    generated so far and n the total number of samples. A record with
    an estimated completion time is logged at most once every
    delay_min seconds.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        # Minimum delay between two log records, in seconds (it is
        # compared with differences of time.time() values).
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()

        if t <= self.last_t + self.delay_min:
            return

        if m <= 0:
            # No sample generated yet: the time per sample, hence the
            # ETA, cannot be estimated.
            return

        dt = (t - self.start_t) / m
        log_string('sample_generation {:d} / {:d}'.format(m, n),
                   ' [ETA ' + time.ctime(t + dt * (n - m)) + ']')
        # Restart the delay, otherwise every later call would log.
        self.last_t = t
def save_examplar_vignettes(data_set, nb, name):
    """Save nb randomly chosen vignettes of data_set in the image name.

    The vignettes are drawn without replacement, each one is rescaled
    individually, and all of them are written as a single image file.
    At most data_set.nb_samples vignettes are saved; nothing is written
    if that number is not positive.
    """
    nb = min(nb, data_set.nb_samples)
    if nb <= 0:
        return

    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    patchwork = None

    for k in range(nb):
        # Plain ints, so that get_batch and the indexing do not receive
        # one-element tensors.
        b, m = divmod(int(n[k]), data_set.batch_size)
        batch, _ = data_set.get_batch(b)

        # Out-of-place operations: the batch may share its storage
        # with the data set, which must not be altered.
        vignette = batch[m].float()
        vignette = vignette - vignette.min()
        peak = float(vignette.max())
        if peak > 0:
            # A constant vignette has a zero range; dividing would
            # fill it with NaNs.
            vignette = vignette / peak

        if patchwork is None:
            patchwork = Tensor(nb, 1, vignette.size(1), vignette.size(2))
        patchwork[k].copy_(vignette)

    torchvision.utils.save_image(patchwork, name)
455 ######################################################################
# The data sets are processed by full batches only.
if args.nb_train_samples % args.batch_size > 0 or args.nb_test_samples % args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # Abort explicitly: carrying on would run with inconsistent sizes.
    raise SystemExit(1)

log_string('############### start ###############')

# Class used to generate and store the vignettes.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
470 ########################################
# Select the network class whose name matches the --model argument.
model_class = None

for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m
        break

if model_class is None:
    print('Unknown model ' + args.model)
    # Abort explicitly: there is nothing to train.
    raise SystemExit(1)

log_string('using model class ' + model_class.name)
481 ########################################
# For every requested problem: create the model, resume from a saved
# state if there is one, train if epochs remain, then test.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
    except (IOError, OSError):
        # No readable saved state: train from scratch. Other failures
        # (e.g. a state that does not fit the model) are not masked,
        # since retraining would overwrite the saved file.
        nb_epochs_done = 0
    else:
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only needed for early termination.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Misclassified images are written only when requested with
        # --save_test_mistakes.
        if args.save_test_mistakes:
            mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_filename_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_filename_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
573 ######################################################################