3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
32 from colorama import Fore, Back, Style
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
44 from torch.nn import functional as fn
46 from torchvision import datasets, transforms, utils
52 ######################################################################
# Command-line interface. All options have defaults, so the script can
# be run without arguments.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
    # NOTE(review): the closing ')' of the ArgumentParser(...) call is
    # not visible in this excerpt.

# Sample counts are consumed in whole batches; the code further down
# checks they are exact multiples of --batch_size.
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

# When > 0, a validation set is generated and training stops early once
# the validation error rate drops to or below this value (see train_model).
parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# Number of randomly chosen training vignettes saved as one PNG
# patchwork per problem (0 disables; see save_examplar_vignettes).
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

# Presumably controls saving images of misclassified test samples; its
# use is not visible in this excerpt -- confirm against the full file.
parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers to process.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
105 ######################################################################
# Append-mode log file shared by log_string() below.
log_file = open(args.log_file, 'a')

# Time of the last full time-stamp tag printed to stdout; log_string()
# re-prints the full date once more than an hour has elapsed.
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
113 # Log and prints the string, with a time stamp. Does not log the
# Log and prints the string, with a time stamp. Does not log the
# remark to the file: the remark only goes to stdout, the file gets the
# time stamp, elapsed time since the previous call, and the message.
#
# NOTE(review): several lines of this function are missing from this
# excerpt (reading the current time into t, the bodies of both 'if'
# branches, the updates of pred_log_t, and the end of the final print);
# the visible code is kept verbatim.
def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    if pred_log_t is None:

    # Elapsed time since the previous log_string call.
    elapsed = '+{:.02f}s'.format(t - pred_log_t)

    # Re-print a full red date line at most once per hour.
    if t > last_tag_t + 3600:
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # File format: spaces in the date are replaced by underscores so the
    # log stays whitespace-separated.
    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \
          + s + Fore.CYAN + remark \
141 ######################################################################
# Record termination signals in the log before the process dies.
# NOTE(review): any further handler body (e.g. an explicit exit) is not
# visible in this excerpt.
def handler_sigint(signum, frame):
    log_string('got sigint')

def handler_sigterm(signum, frame):
    log_string('got sigterm')

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
154 ######################################################################
156 # Afroze's ShallowNet
159 # ----------------------
161 # -- conv(21x21 x 6) -> 108x108 6
162 # -- max(2x2) -> 54x54 6
163 # -- conv(19x19 x 16) -> 36x36 16
164 # -- max(2x2) -> 18x18 16
165 # -- conv(18x18 x 120) -> 1x1 120
166 # -- reshape -> 120 1
167 # -- full(120x84) -> 84 1
168 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    # Shallow LeNet-style binary classifier matching the architecture
    # sketch in the comments above: three convolutions collapsing a
    # 1-channel image to a 120-dim vector, then two linear layers.
    # NOTE(review): the 'def __init__(self):' header (and presumably a
    # 'name' class attribute, used by the model-selection loop below)
    # is missing from this excerpt.
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)    # 2-class output

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # NOTE(review): the flattening of x, the final self.fc2 call and
        # the return statement are missing from this excerpt.
        x = fn.relu(self.fc1(x))
190 ######################################################################
class AfrozeDeepNet(nn.Module):
    # Deeper AlexNet-style variant: five convolutions (stride-4 stem,
    # then 3x3 layers), three linear layers, 2-class output.
    # NOTE(review): the 'def __init__(self):' header (and presumably a
    # 'name' class attribute) is missing from this excerpt.
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # NOTE(review): only the three poolings survive in this excerpt;
        # the convolution/non-linearity calls between them, the
        # flattening, the linear layers and the return are missing.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
240 ######################################################################
class DeepNet2(nn.Module):
    # Wider variant of AfrozeDeepNet: same five-convolution layout but
    # 256 channels throughout and 512-wide linear layers.
    # NOTE(review): the 'def __init__(self):' header (and presumably a
    # 'name' class attribute) is missing from this excerpt.
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        # NOTE(review): only the three poolings survive in this excerpt;
        # the convolution/non-linearity calls between them, the
        # flattening, the linear layers and the return are missing.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
287 ######################################################################
class DeepNet3(nn.Module):
    # Deeper variant: seven convolutions (stride-4 stem, then six 3x3
    # layers at 128 channels), three linear layers, 2-class output.
    # NOTE(review): the 'def __init__(self):' header (and presumably a
    # 'name' class attribute) is missing from this excerpt.
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # NOTE(review): only the three poolings survive in this excerpt;
        # the convolution/non-linearity calls between them, the
        # flattening, the linear layers and the return are missing.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
342 ######################################################################
# Counts the classification errors of model on data_set, optionally
# saving an image of every misclassified input.
#
# NOTE(review): the error-counter initialization/increment and the
# final return are missing from this excerpt.
def nb_errors(model, data_set, mistake_filename_pattern = None):
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all: index of the max output is the predicted class.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                if mistake_filename_pattern is not None:
                    img = input[i].clone()
                    # NOTE(review): 'b + i' does not uniquely index a
                    # sample across batches (b * batch_size + i would),
                    # so later batches can overwrite earlier mistake
                    # files -- confirm intent.
                    torchvision.utils.save_image(img,
                                                 mistake_filename_pattern.format(b + i, target[i]))
363 ######################################################################
# Trains model with SGD and cross-entropy for the epochs remaining
# after nb_epochs_done, saving [ state_dict, epochs-done ] to
# model_filename after every epoch. If validation_set is not None,
# stops early once the validation error rate drops to
# args.validation_error_threshold or below.
#
# NOTE(review): several lines are missing from this excerpt (the
# acc_loss initialization, the optimizer zero_grad/backward/step calls,
# closing parentheses of the log_string call, and any return); the
# visible code is kept verbatim.
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):

        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]

        # Average seconds per epoch so far, for the ETA estimate below.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        # Checkpoint: weights plus the number of completed epochs, so a
        # later run can resume where this one stopped.
        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)

            # Early-termination criterion from --validation_error_threshold.
            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
408 ######################################################################
# Record every parsed command-line argument and its value in the log,
# so each run is reproducible from its log file alone.
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
413 ######################################################################
# Formats a sample count compactly for use in checkpoint file names:
# exact multiples of 1e6 become '<k>M', exact multiples of 1e3 become
# '<k>K', anything else is returned as plain digits.
def int_to_suffix(n):
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        # Fix: without this branch the function fell through and
        # returned None for counts that are not exact multiples of 1000.
        return str(n)
class vignette_logger():
    # Rate-limited progress callback handed to VignetteSet: when called
    # with (total, generated-so-far) it logs an ETA, at most once every
    # delay_min seconds.
    # NOTE(review): the line reading the current time into t, the update
    # of self.last_t, and part of the format arguments are missing from
    # this excerpt.
    def __init__(self, delay_min = 60):
        self.start_t = time.time()   # generation start, for the rate estimate
        self.last_t = self.start_t   # last time a message was emitted
        # Minimum interval between messages, in seconds (time.time()
        # units) despite the '_min' suffix.
        self.delay_min = delay_min

    def __call__(self, n, m):
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m    # seconds per generated sample so far
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
# Saves a patchwork image of nb vignettes drawn at random without
# replacement from data_set into the file name.
#
# NOTE(review): the line extracting sample m from the batch tensor i
# (and any normalization of it) is missing from this excerpt.
def save_examplar_vignettes(data_set, nb, name):
    # nb distinct flat sample indices.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        # Convert the flat index into (batch index, offset within batch).
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)
        # Allocate the patchwork lazily, once the vignette size is known.
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
454 ######################################################################
# Samples are consumed in whole batches, so the counts must divide
# evenly by the batch size.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # NOTE(review): the early exit after this message is missing from
    # this excerpt.

log_string('############### start ###############')

# Pick the vignette container class; per its help text, the compressed
# variant trades CPU time for a smaller memory footprint.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
    # NOTE(review): the 'else:' introducing the two lines below is
    # missing from this excerpt.
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
469 ########################################
# Resolve --model to one of the network classes by matching their
# 'name' class attribute.
# NOTE(review): the 'model_class = None' initialization and the
# 'model_class = m' assignment inside the loop are missing from this
# excerpt, as is the exit after the 'Unknown model' message.
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:

if model_class is None:
    print('Unknown model ' + args.model)

# NOTE(review): 'm' here is the loop variable left over after the loop
# (always DeepNet3); presumably model_class.name was intended -- confirm.
log_string('using model class ' + m.name)
480 ########################################
# Main loop: for each requested SVRT problem, build (or resume) a model,
# train it if needed, and measure its train/test error.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint name encodes model, problem and training-set size,
    # e.g. 'deepnet_pb:4_ns:100K.state'.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    # NOTE(review): the 'nb_parameters = 0' initialization is missing
    # from this excerpt.
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model
    # NOTE(review): the try/except wrapping this load, and the
    # initialization of nb_epochs_done when no checkpoint exists, are
    # missing from this excerpt.

    model_state_dict, nb_epochs_done = torch.load(model_filename)
    model.load_state_dict(model_state_dict)
    log_string('loaded_model ' + model_filename)

    ##################################################
    # Train only if the checkpoint has fewer epochs than requested.

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        # NOTE(review): the 't = time.time()' preceding the generation
        # timing below is missing from this excerpt.
        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only generated when early stopping is
        # requested via --validation_error_threshold.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
            # NOTE(review): the 'else:' introducing the line below is
            # missing from this excerpt.
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_train_errors / train_set.nb_samples,
            train_set.nb_samples)

    ##################################################
    # Test freshly trained models, and loaded ones if requested.

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = 'mistake_{:d}_{:06d}.png')

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
572 ######################################################################