3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import distutils.util
import re
import signal
import time

from colorama import Fore, Back, Style

import torch
import torchvision

from torch import optim
from torch import multiprocessing
from torch import FloatTensor as Tensor
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as fn

from torchvision import datasets, transforms, utils

import svrtset
52 ######################################################################
# Command-line interface. Defaults reproduce the reference experiment;
# the ArgumentParser call was left unclosed, which is fixed here.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
105 ######################################################################
# Logging state. The log file is opened in append mode so successive
# runs accumulate in the same file, separated by a time-stamped banner.
log_file = open(args.log_file, 'a')

log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')

# Time of the previous log entry; None until the first call to
# log_string (it reads this global, so it must be initialized here).
pred_log_t = None
# Time of the last full-date tag printed by log_string.
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the
# remark, only prints it.
def log_string(s, remark = ''):
    """Append s to the log file and print it, both prefixed with a time
    stamp and the time elapsed since the previous entry. The remark is
    printed but not logged."""
    global pred_log_t, last_tag_t

    t = time.time()

    if pred_log_t is None:
        # First entry of the run: no previous time to diff against.
        elapsed = 'NA'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Print a full human-readable date tag at most once per hour.
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \
          + ' ' \
          + s + Fore.CYAN + remark \
          + Style.RESET_ALL)
145 ######################################################################
# Minimal signal handlers: record the received signal in the log.
# NOTE(review): as written they only log and the process keeps running;
# confirm whether termination (e.g. exit(0)) was intended afterwards.
def handler_sigint(signum, frame):
    log_string('got sigint')

def handler_sigterm(signum, frame):
    log_string('got sigterm')

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
158 ######################################################################
160 # Afroze's ShallowNet
163 # ----------------------
165 # -- conv(21x21 x 6) -> 108x108 6
166 # -- max(2x2) -> 54x54 6
167 # -- conv(19x19 x 16) -> 36x36 16
168 # -- max(2x2) -> 18x18 16
169 # -- conv(18x18 x 120) -> 1x1 120
170 # -- reshape -> 120 1
171 # -- full(120x84) -> 84 1
172 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's ShallowNet: three convolutional layers followed by two
    fully connected layers, mapping a 1x128x128 vignette to 2 class
    scores (see the layer-size comment above for the spatial sizes)."""

    # Identifier matched against --model by the selection loop below.
    # NOTE(review): exact string inferred from the naming scheme; confirm.
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))  # -> 6x54x54
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))  # -> 16x18x18
        x = fn.relu(self.conv3(x))                                # -> 120x1x1
        # Flatten the 120x1x1 maps before the fully connected layers.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
194 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Deeper AlexNet-like network: five convolutional layers (three
    max-poolings) followed by three fully connected layers, mapping a
    1x128x128 vignette to 2 class scores."""

    # Identifier matched against --model (the parser's default model).
    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = self.conv1(x)                    # 128x128 -> 32x32
        x = fn.max_pool2d(x, kernel_size=2)  # -> 16x16
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 8x8
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 96x4x4 = 1536 features
        x = fn.relu(x)

        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
244 ######################################################################
class DeepNet2(nn.Module):
    """Variant of AfrozeDeepNet with wider layers (256 channels, 512
    hidden units); same five-conv / three-fc topology, 1x128x128 input,
    2 class scores."""

    # Identifier matched against --model by the selection loop below.
    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        x = self.conv1(x)                    # 128x128 -> 32x32
        x = fn.max_pool2d(x, kernel_size=2)  # -> 16x16
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 8x8
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 256x4x4 = 4096 features
        x = fn.relu(x)

        x = x.view(-1, 4096)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
291 ######################################################################
class DeepNet3(nn.Module):
    """Deeper variant with seven convolutional layers (three
    max-poolings) and three fully connected layers; 1x128x128 input,
    2 class scores."""

    # Identifier matched against --model by the selection loop below.
    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = self.conv1(x)                    # 128x128 -> 32x32
        x = fn.max_pool2d(x, kernel_size=2)  # -> 16x16
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 8x8
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.relu(x)

        x = self.conv6(x)
        x = fn.relu(x)

        x = self.conv7(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 128x4x4 = 2048 features
        x = fn.relu(x)

        x = x.view(-1, 2048)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
346 ######################################################################
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Return the number of classification errors of model on data_set.

    When mistake_filename_pattern is given, each mistaken vignette is
    saved as an image; the pattern is formatted with the global sample
    index and the true class. All callers use the returned count, so
    the counter (missing in the previous revision) is initialized,
    incremented, and returned.
    """
    ne = 0

    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all prediction: index of the maximum class score.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
                if mistake_filename_pattern is not None:
                    img = input[i].clone()
                    # Normalize to [0, 1] so the saved image is visible.
                    img.sub_(img.min())
                    if img.max() > 0: img.div_(img.max())
                    k = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(k, target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)

    return ne
368 ######################################################################
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train model with SGD + cross-entropy on train_set, starting at
    epoch nb_epochs_done, saving the state after every epoch.

    If validation_set is not None, training stops early once the
    validation error rate falls below args.validation_error_threshold.
    Returns the trained model. The previous revision never called
    backward()/step() (so the model was never optimized), never
    initialized the loss accumulator, and never broke out on early
    termination; all three are restored here.
    """
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0

        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Per-epoch average duration, used for the ETA estimate.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        # Save the state and the number of epochs done, so an
        # interrupted run can resume (see the loading code below).
        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
413 ######################################################################
# Record every command-line argument value in the log, so each run is
# fully reproducible from its log file alone.
for name in vars(args):
    value = getattr(args, name)
    log_string('argument ' + str(name) + ' ' + str(value))
418 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string: exact multiples of 1e6 get an 'M'
    suffix, exact multiples of 1e3 a 'K' suffix, anything else the
    plain decimal. The previous revision fell off the end and returned
    None for the plain-decimal case."""
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
class vignette_logger():
    """Callable progress logger for vignette generation: logs at most
    once every delay_min seconds. Called as logger(n, m) with n the
    total number of samples and m the number generated so far.
    NOTE(review): argument roles inferred from the ETA computation
    dt * (n - m); confirm against the VignetteSet caller."""

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        if t > self.last_t + self.delay_min:
            # Throttle: remember when we last logged.
            self.last_t = t
            # Average generation time per sample so far.
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
def save_examplar_vignettes(data_set, nb, name):
    """Save a patchwork image of nb vignettes drawn at random (without
    replacement) from data_set into the file name."""
    # Random sample indices; each is split into a batch index b and an
    # offset m inside that batch.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)
        # Select sample m of the batch (the previous revision copied
        # the whole batch, which cannot fit a single patchwork slot)
        # and normalize it to [0, 1] for display.
        i = i[m].float()
        i.sub_(i.min())
        if i.max() > 0: i.div_(i.max())
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
459 ######################################################################
# The batch loops assume every batch is full, so the sample counts must
# be exact multiples of the batch size. The previous revision printed
# the message but kept running; abort instead.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    raise ValueError('invalid number of samples')

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
472 ########################################
# Select the model class whose .name matches --model. The previous
# revision never assigned model_class, did not abort on an unknown
# model, and logged the stale loop variable m instead of the selection.
model_class = None
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m

if model_class is None:
    print('Unknown model ' + args.model)
    raise ValueError('unknown model ' + args.model)

log_string('using model class ' + model_class.name)
483 ########################################
# Main loop: for each requested SVRT problem, build (or reload) a
# model, train it if needed, and measure train/test error rates.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # State file name encodes model, problem and training-set size.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except FileNotFoundError:
        # No saved state for this configuration: train from scratch.
        nb_epochs_done = 0

    ##################################################
    # Train if the requested number of epochs is not reached yet

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            # No early stopping requested.
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if the model was just trained, or if requested for loaded models

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Wire the otherwise-unused --save_test_mistakes flag: only
        # dump mistake images when it is set.
        nb_test_errors = nb_errors(
            model, test_set,
            mistake_filename_pattern = ('mistake_{:06d}_{:d}.png'
                                        if args.save_test_mistakes else None)
        )

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
575 ######################################################################