3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
32 from colorama import Fore, Back, Style
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
44 from torch.nn import functional as fn
46 from torchvision import datasets, transforms, utils
52 ######################################################################
# Command-line interface.  NOTE(review): the closing parenthesis of the
# ArgumentParser(...) call is not visible in this view of the file --
# confirm against the full source before editing this section.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter

# Number of samples generated for training.
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

# When > 0, training stops early once the validation error rate drops to
# this fraction or below (see train_model).
parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# How many sample vignettes to save as a PNG montage per problem.
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers to process.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
102 ######################################################################
# Append-mode log file shared by log_string below.
log_file = open(args.log_file, 'a')

# Time stamp of the last full-date tag printed to stdout (see log_string).
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the
# Writes 's' to the log file and echoes it to stdout, both prefixed with
# a time stamp and the time elapsed since the previous call; 'remark' is
# printed (in cyan) but not logged.  NOTE(review): several body lines
# are elided in this view of the file (the assignment of t, the bodies
# of both if branches, the pred_log_t update, and the closing of the
# final print call) -- confirm against the full source.
def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    if pred_log_t is None:

    elapsed = '+{:.02f}s'.format(t - pred_log_t)

    # Re-print the full date at most once per hour.
    if t > last_tag_t + 3600:
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # Spaces in the time stamp are replaced with underscores so each log
    # line stays whitespace-separated and easy to parse.
    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \
          + s + Fore.CYAN + remark \
138 ######################################################################
def handler_sigint(signum, frame):
    """Signal handler: leave a trace in the log when SIGINT arrives."""
    log_string('got sigint')
def handler_sigterm(signum, frame):
    """Signal handler: leave a trace in the log when SIGTERM arrives."""
    log_string('got sigterm')
# Route SIGINT / SIGTERM through the handlers above so an interruption
# leaves a trace in the log.
signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
151 ######################################################################
153 # Afroze's ShallowNet
156 # ----------------------
158 # -- conv(21x21 x 6) -> 108x108 6
159 # -- max(2x2) -> 54x54 6
160 # -- conv(19x19 x 16) -> 36x36 16
161 # -- max(2x2) -> 18x18 16
162 # -- conv(18x18 x 120) -> 1x1 120
163 # -- reshape -> 120 1
164 # -- full(120x84) -> 84 1
165 # -- full(84x2) -> 2 1
# Implementation of Afroze's ShallowNet (layer-by-layer sizes are given
# in the comment above).  Binary classifier: two output logits.
# NOTE(review): the 'def __init__(self):' line, the flattening of the
# conv activations, the final fc2 application and the return statement
# are elided in this view of the file, as is the class-level 'name'
# attribute referenced by the model-selection code -- confirm against
# the full source.
class AfrozeShallowNet(nn.Module):

        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        # Two conv+pool stages, then an 18x18 conv that consumes the
        # whole remaining spatial map (acting as a fully connected layer).
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.fc1(x))
187 ######################################################################
# Deeper convolutional classifier: stride-4 stem conv, four 3x3/5x5
# convs, three fully connected layers ending in two output logits.
# NOTE(review): the 'def __init__(self):' line, the conv/relu calls
# interleaved with the poolings in forward, the flattening to 1536
# features, the fc applications and the return statement are elided in
# this view of the file, as is the class-level 'name' attribute used by
# the model-selection code -- confirm against the full source.
class AfrozeDeepNet(nn.Module):

        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
237 ######################################################################
# Wider variant of AfrozeDeepNet: 256-channel convolutions and 512-wide
# fully connected layers, still two output logits.  NOTE(review): the
# 'def __init__(self):' line, the conv/relu calls interleaved with the
# poolings in forward, the flattening to 4096 features, the fc
# applications and the return statement are elided in this view of the
# file, as is the class-level 'name' attribute used by the
# model-selection code -- confirm against the full source.
class DeepNet2(nn.Module):

        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
284 ######################################################################
# Deeper variant: seven 128-channel convolutions followed by three fully
# connected layers ending in two output logits.  NOTE(review): the
# 'def __init__(self):' line, the conv/relu calls interleaved with the
# poolings in forward, the flattening to 2048 features, the fc
# applications and the return statement are elided in this view of the
# file, as is the class-level 'name' attribute used by the
# model-selection code -- confirm against the full source.
class DeepNet3(nn.Module):

        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
339 ######################################################################
# Counts the classification errors of 'model' over the whole of
# 'data_set'.  NOTE(review): the accumulator initialisation, the
# increment on a mismatch, and the return statement are elided in this
# view of the file -- confirm against the full source.
def nb_errors(model, data_set):
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all prediction: index of the largest output logit.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
354 ######################################################################
# Trains 'model' on 'train_set' with SGD + cross-entropy, saving a
# [state_dict, nb_epochs_done] checkpoint to 'model_filename' after every
# epoch so training can resume.  When 'validation_set' is not None,
# training may terminate early once the validation error rate falls to
# args.validation_error_threshold or below.  NOTE(review): several body
# lines are elided in this view of the file (acc_loss initialisation,
# optimizer zero_grad/backward/step, the closing parentheses of the
# log_string calls, the loop break / final return) -- confirm against
# the full source.
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    # Resume from nb_epochs_done when a checkpoint was loaded.
    for e in range(nb_epochs_done, args.nb_epochs):

        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]

        # Mean wall-clock time per epoch so far, used for the ETA below.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        # Checkpoint carries the epoch count for resuming.
        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)

            # Early-stopping criterion.
            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
399 ######################################################################
# Record the full command-line configuration in the log.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
404 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string for use in file names.

    Exact multiples of one million / one thousand are abbreviated with
    an 'M' / 'K' suffix (e.g. 2000000 -> '2M', 100000 -> '100K'); any
    other value is returned in plain decimal.  The decimal fallback is
    explicit: without it the function implicitly returned None for
    non-round values, which crashed the string concatenation that builds
    the model file name.
    """
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
# Callable passed as 'logger' to the VignetteSet constructors: reports
# sample-generation progress via log_string, at most once every
# 'delay_min' seconds.
class vignette_logger():
    def __init__(self, delay_min = 60):
        # delay_min is a minimum delay in seconds between two reports.
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    # Presumably n is the total number of samples and m the number
    # generated so far -- TODO confirm against svrtset.  NOTE(review):
    # the assignment of t and the update of self.last_t are elided in
    # this view of the file, and the format string has two placeholders
    # but only 'n' is visibly passed -- confirm against the full source.
    def __call__(self, n, m):
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
# Saves 'nb' vignettes drawn at random from 'data_set' into a single
# montage image file 'name'.  ('examplar' spelling is kept: it is the
# public name used by the caller.)  NOTE(review): the extraction of
# sample m from the batch appears to be elided in this view of the file
# -- as written, patchwork[k].copy_(i) would receive a whole batch; the
# bare 'import torchvision' needed for torchvision.utils is also not
# among the visible imports.  Confirm both against the full source.
def save_examplar_vignettes(data_set, nb, name):
    # nb random sample indices, without replacement.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        # Locate sample n[k] as (batch index, index within batch).
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)

        # Lazily allocate the montage once the vignette size is known.
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
445 ######################################################################
# Sanity check: the sets are processed batch-by-batch, so sample counts
# must be multiples of the batch size.  NOTE(review): the exit/raise
# after this message is elided in this view of the file.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')

log_string('############### start ###############')

# Select the vignette container class; the compressed variant trades CPU
# time for a smaller memory footprint.  NOTE(review): the 'else:' line
# introducing the second branch is elided in this view of the file.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet

    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
460 ########################################
# Resolve the model class from its 'name' class attribute.
# NOTE(review): the initialisation of model_class, its assignment inside
# the loop, and the exit taken when the model is unknown are elided in
# this view of the file -- confirm against the full source.
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:

if model_class is None:
    print('Unknown model ' + args.model)

log_string('using model class ' + m.name)
471 ########################################
# Main loop: for each requested SVRT problem, build (or resume) a model,
# train it if needed, and measure train/test error rates.
# NOTE(review): numerous lines are elided in this view of the file (the
# try/except presumably wrapping the checkpoint load, nb_parameters
# initialisation, several closing parentheses and missing format
# arguments in the log_string calls) -- reconstruct from the full source
# before changing any behaviour here.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # File name encodes model, problem and training-set size,
    # e.g. 'deepnet_pb:4_ns:100K.state'.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    model_state_dict, nb_epochs_done = torch.load(model_filename)
    model.load_state_dict(model_state_dict)
    log_string('loaded_model ' + model_filename)

    ##################################################

    # Train only if the checkpoint has fewer epochs than requested.
    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only generated when early stopping is on.
        # NOTE(review): the 'else:' line before 'validation_set = None'
        # is elided in this view of the file.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())

            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_train_errors / train_set.nb_samples,
            train_set.nb_samples)

    ##################################################

    # Test freshly-trained models and, on request, loaded ones.
    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
562 ######################################################################