3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
30 from colorama import Fore, Back, Style
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
42 from torch.nn import functional as fn
44 from torchvision import datasets, transforms, utils
50 ######################################################################
# Command-line interface. Defaults reproduce the paper's standard
# configuration; boolean options are parsed with distutils.util.strtobool
# so that '--opt True' / '--opt False' both work on the command line.
# (Fix: the closing parenthesis of the ArgumentParser(...) call was
# missing in the mangled listing.)
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
100 ######################################################################
# Open the log in append mode so successive runs accumulate in one file.
log_file = open(args.log_file, 'a')

# Time at which a full date tag was last printed on the console.
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Log and prints the string, with a time stamp. Does not log the
# remark, which is shown on the console only — TODO confirm, the
# continuation of this comment is elided in the listing.
def log_string(s, remark = ''):
    # pred_log_t: time of the previous call (None before the first one);
    # last_tag_t: time the last full date tag was printed.
    global pred_log_t, last_tag_t

    # NOTE(review): the listing elides lines here — presumably
    # `t = time.time()` and the None-branch assignment of `elapsed`;
    # restore from version control.
    if pred_log_t is None:
        # NOTE(review): an `else:` line appears to be elided — this
        # expression reads pred_log_t and must belong to the non-None
        # branch (it would raise on None).
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    # Print a visible red date tag at most once per hour.
    # NOTE(review): the update `last_tag_t = t` appears to be elided.
    if t > last_tag_t + 3600:
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # Log line: timestamp with spaces replaced by '_', elapsed time, message.
    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')

    # Console gets the colored version plus the (unlogged) remark.
    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
132 ######################################################################
134 # Afroze's ShallowNet
137 # ----------------------
139 # -- conv(21x21 x 6) -> 108x108 6
140 # -- max(2x2) -> 54x54 6
141 # -- conv(19x19 x 16) -> 36x36 16
142 # -- max(2x2) -> 18x18 16
143 # -- conv(18x18 x 120) -> 1x1 120
144 # -- reshape -> 120 1
145 # -- full(120x84) -> 84 1
146 # -- full(84x2) -> 2 1
# Afroze's ShallowNet: three convolutions followed by two fully connected
# layers, producing 2 class scores (see the layer-size table above).
class AfrozeShallowNet(nn.Module):
    # NOTE(review): the listing elides lines here — the `def __init__(self):`
    # header (and presumably a class-level `name` attribute, which the
    # model-selection loop reads as `m.name`) is missing; restore from
    # version control.
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # NOTE(review): a flattening reshape (e.g. x.view(-1, 120)) appears
        # to be elided before the fully connected layers.
        x = fn.relu(self.fc1(x))
        # NOTE(review): the final fc2 application and the `return x`
        # statement appear to be elided in the listing.
168 ######################################################################
# AlexNet-style deeper variant: five convolutions (32..128 channels) and
# three fully connected layers ending in 2 class scores.
class AfrozeDeepNet(nn.Module):
    # NOTE(review): the `def __init__(self):` header (and presumably a
    # class-level `name` attribute used by the model-selection loop) is
    # elided in the listing.
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        # 1536 = flattened conv5 output (4x4x96 presumably) — TODO confirm.
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # NOTE(review): only the three max-pooling calls survive in this
        # listing — the conv/relu applications between them, the flattening
        # reshape, the fc layers and the final `return` are elided; restore
        # the full body from version control.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
218 ######################################################################
# Wider variant of AfrozeDeepNet: 256 channels in the middle convolutions
# and 512-unit fully connected layers.
class DeepNet2(nn.Module):
    # NOTE(review): the `def __init__(self):` header (and presumably a
    # class-level `name` attribute used by the model-selection loop) is
    # elided in the listing.
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # 4096 = flattened conv output (4x4x256 presumably) — TODO confirm.
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        # NOTE(review): only the pooling calls survive in this listing —
        # the conv/relu applications, flattening and fc layers are elided;
        # restore the full body from version control.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
265 ######################################################################
# Deeper variant: seven 128-channel convolutions and three fully connected
# layers ending in 2 class scores.
class DeepNet3(nn.Module):
    # NOTE(review): the `def __init__(self):` header (and presumably a
    # class-level `name` attribute used by the model-selection loop) is
    # elided in the listing.
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # 2048 = flattened conv output (4x4x128 presumably) — TODO confirm.
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # NOTE(review): only the pooling calls survive in this listing —
        # the conv/relu applications, flattening and fc layers are elided;
        # restore the full body from version control.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
320 ######################################################################
# Count classification errors of `model` over every batch of `data_set`.
def nb_errors(model, data_set):
    # NOTE(review): the initialisation of the error counter is elided in
    # the listing; restore from version control.
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all prediction: index of the max output per sample.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                # NOTE(review): the counter increment here and the
                # function's final return statement are elided in the
                # listing; restore from version control.
335 ######################################################################
# Train `model` with SGD + cross-entropy, checkpointing after every epoch
# so an interrupted run can resume from `nb_epochs_done`. When a
# validation set is given, training may stop early once the validation
# error falls to args.validation_error_threshold.
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    batch_size = args.batch_size  # NOTE(review): unused in the visible code
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # NOTE(review): the body of this branch (presumably
        # `criterion.cuda()`) is elided in the listing.

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        # NOTE(review): `acc_loss = 0` and the gradient statements
        # (zero_grad / loss.backward / optimizer.step) are elided in the
        # listing; restore from version control.
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]

        # Average wall-clock per epoch, used for the ETA estimate below.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        # Checkpoint: state dict plus the number of completed epochs.
        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            # NOTE(review): the closing parenthesis of this log_string
            # call is elided in the listing.

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                # NOTE(review): the early-exit (break/return) after this
                # message and the function's final return appear to be
                # elided in the listing.
380 ######################################################################
# Record the full run configuration in the log, one line per argument,
# so experiments can be reproduced from the log file alone.
for key, value in vars(args).items():
    log_string('argument ' + str(key) + ' ' + str(value))
385 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string: '2M' for exact multiples of one
    million, '5K' for exact multiples of one thousand, otherwise the
    plain decimal digits (e.g. 100000 -> '100K', 1500 -> '1500')."""
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    elif n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    else:
        # Fix: without this fallback the function silently returned None
        # for values that are not exact multiples of 1000, which would
        # crash the filename construction below.
        return str(n)
# Progress logger passed to the vignette generators; rate-limited so it
# emits at most one message every `delay_min` seconds.
class vignette_logger():
    def __init__(self, delay_min = 60):
        self.start_t = time.time()   # generation start time
        self.last_t = self.start_t   # time of the last emitted message
        self.delay_min = delay_min   # minimum seconds between messages

    # Called by the generator — presumably with m samples generated out
    # of n total; TODO confirm, the call site is not visible here.
    def __call__(self, n, m):
        # NOTE(review): `t = time.time()` appears to be elided before this
        # test, as does the update `self.last_t = t` inside the branch.
        if t > self.last_t + self.delay_min:
            # Average seconds per generated sample so far.
            dt = (t - self.start_t) / m
            # NOTE(review): the format string has two {:d} slots but only
            # `n` is passed — an `m,` argument line and the closing
            # parenthesis of the call are elided in the listing.
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
# Save `nb` randomly chosen vignettes from `data_set` as a single image
# grid written to file `name`.
def save_examplar_vignettes(data_set, nb, name):
    # Random sample of `nb` distinct sample indices.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        b = n[k] // data_set.batch_size   # batch containing sample n[k]
        m = n[k] % data_set.batch_size    # position inside that batch
        i, t = data_set.get_batch(b)
        # NOTE(review): lines selecting and normalizing the single
        # vignette (e.g. `i = i[m].float()` plus a rescaling) appear to
        # be elided — as written `i` is still the whole batch; restore
        # from version control.
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
426 ######################################################################
# The vignette sets are generated and consumed in whole batches; refuse
# sample counts that do not divide evenly by the batch size.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # NOTE(review): the abort (e.g. raise/sys.exit) after this message is
    # elided in the listing.

log_string('############### start ###############')

# Choose between the compressed and plain in-memory vignette containers.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
    # NOTE(review): an `else:` line is elided here — the two statements
    # below are the uncompressed branch; restore from version control.
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
441 ########################################
# Resolve the model class whose (class-level) `name` attribute matches
# the --model argument.
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        # NOTE(review): the assignment `model_class = m` (and the prior
        # `model_class = None` initialisation) appear to be elided in the
        # listing; restore from version control.

if model_class is None:
    print('Unknown model ' + args.model)
    # NOTE(review): the abort after this message is elided in the listing.

log_string('using model class ' + m.name)
452 ########################################
# Main experiment loop: for each requested SVRT problem, build (or load)
# a model, train it if needed, and measure train/test error.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint filename, e.g. 'deepnet_pb:4_ns:100K.state'.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    # NOTE(review): `nb_parameters = 0` appears to be elided before this
    # accumulation loop.
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model
    # NOTE(review): the enclosing `try:` (with an `except` presumably
    # setting nb_epochs_done = 0 when no checkpoint exists) appears to be
    # elided around the indented load below; restore from version control.
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)

    ##################################################
    # Train if the checkpoint has fewer epochs than requested.

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        # NOTE(review): `t = time.time()` appears to be elided before the
        # generation-throughput measurement below.
        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        # NOTE(review): the closing parenthesis of this log_string call is
        # elided in the listing.

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # Only build a validation set when early stopping is requested.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
            # NOTE(review): an `else:` line is elided before the next
            # statement — `validation_set = None` is the no-early-stop
            # branch; restore from version control.
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_train_errors / train_set.nb_samples,
            train_set.nb_samples)
        # NOTE(review): format arguments (presumably the problem number
        # and the raw error count) and the closing parenthesis are elided
        # — the format string has four slots but only two values survive.

    ##################################################
    # Test if we just trained, or on explicit request for loaded models.

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
        # NOTE(review): the remaining format arguments and the closing
        # parentheses of this call are elided in the listing.
543 ######################################################################