3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
30 from colorama import Fore, Back, Style
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
42 from torch.nn import functional as fn
44 from torchvision import datasets, transforms, utils
50 ######################################################################
# Command-line interface.  Boolean flags are parsed with
# distutils.util.strtobool so values like '--compress_vignettes False'
# are accepted.
# NOTE(review): distutils is deprecated (removed in Python 3.12, PEP 632);
# consider replacing strtobool with a small local helper.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter  # show defaults in --help

# Data set sizes, in samples.  Each must be a multiple of --batch_size
# (checked before the main loop below).
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

# Early stopping: training ends once the validation error rate is at or
# below this fraction.  0.0 disables validation entirely.
parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# Number of sample vignettes saved as a PNG patchwork per problem
# (0 disables).
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers to process.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()

######################################################################

# Opened in append mode so successive runs accumulate into the same log.
# NOTE(review): this handle stays open for the lifetime of the process
# and is never explicitly closed.
log_file = open(args.log_file, 'a')

# Wall-clock time of the last full date tag printed to the console
# (used by log_string to re-print the date at most once per hour).
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the remark.
def log_string(s, remark = ''):
    """Write *s* to the log file and echo it to the console.

    The console line is colorized (colorama) and carries the current
    time plus the delay since the previous call; *remark* (typically an
    ETA) is shown on the console but deliberately not written to the
    log file.
    """
    global pred_log_t, last_tag_t

    # pred_log_t is the time of the previous log_string call (None on
    # the first call).
    # NOTE(review): 't' is not assigned in the visible code -- presumably
    # t = time.time() just above; confirm against the complete source.
    if pred_log_t is None:
    elapsed = '+{:.02f}s'.format(t - pred_log_t)

    # Re-print a full date tag at most once per hour.
    # NOTE(review): last_tag_t does not appear to be refreshed in the
    # visible code -- confirm it is updated when the tag is printed.
    if t > last_tag_t + 3600:
    print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # Spaces inside ctime() are replaced with '_' so each log line stays
    # splittable on single spaces.
    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
132 ######################################################################
134 # Afroze's ShallowNet
137 # ----------------------
139 # -- conv(21x21 x 6) -> 108x108 6
140 # -- max(2x2) -> 54x54 6
141 # -- conv(19x19 x 16) -> 36x36 16
142 # -- max(2x2) -> 18x18 16
143 # -- conv(18x18 x 120) -> 1x1 120
144 # -- reshape -> 120 1
145 # -- full(120x84) -> 84 1
146 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    # Shallow conv net mapping a single-channel vignette to 2 class
    # logits; the layer-by-layer spatial sizes are given in the comment
    # block above the class (128x128 input assumed -- TODO confirm).
        # NOTE(review): the 'def __init__(self):' line and the class
        # attribute 'name' (used for --model matching below) are not
        # visible in this excerpt.
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)    # 128 -> 108, 6 maps
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)   # 54 -> 36, 16 maps
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18) # 18 -> 1x1, 120 maps
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)                     # 2-way classification

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))  # 108 -> 54
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))  # 36 -> 18
        x = fn.relu(self.conv3(x))                                # -> (batch, 120, 1, 1)
        x = fn.relu(self.fc1(x))
        # NOTE(review): the flatten between conv3 and fc1, the final
        # fc2 application and the return statement are not visible in
        # this excerpt -- confirm against the complete source.
168 ######################################################################
class AfrozeDeepNet(nn.Module):
    # Deeper AlexNet-style variant: five conv layers (stride-4 stem)
    # followed by three fully connected layers ending in 2 logits.
        # NOTE(review): the 'def __init__(self):' line and the class
        # attribute 'name' are not visible in this excerpt.
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        # 1536 = 96 channels x 4x4 spatial -- presumably; confirm for the
        # actual input resolution.
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # NOTE(review): only the three poolings are visible here; the
        # conv/ReLU applications between them, the flatten and the fc
        # stage are not visible in this excerpt.
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.max_pool2d(x, kernel_size=2)
218 ######################################################################
class DeepNet2(nn.Module):
    # Wider variant of AfrozeDeepNet: 128-channel convolutions
    # throughout and 512-unit fully connected layers.
        # NOTE(review): the 'def __init__(self):' line and the class
        # attribute 'name' are not visible in this excerpt.
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # 2048 = 128 channels x 4x4 spatial -- presumably; confirm for the
        # actual input resolution.
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        # NOTE(review): only the three poolings are visible here; the
        # conv/ReLU applications between them, the flatten and the fc
        # stage are not visible in this excerpt.
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.max_pool2d(x, kernel_size=2)
265 ######################################################################
def nb_errors(model, data_set):
    """Count the classification errors of *model* over *data_set*.

    Prediction is winner-take-all over the two output units; the data
    set is consumed in whole batches via get_batch.
    """
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Index of the maximal output unit = predicted class.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
            # NOTE(review): the error counter initialization, its
            # increment and the final return are not visible in this
            # excerpt.
280 ######################################################################
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train *model* with SGD + cross-entropy, one checkpoint per epoch.

    Training resumes at epoch *nb_epochs_done* (as recovered from a
    loaded checkpoint).  After each epoch the state dict and the number
    of completed epochs are saved to *model_filename*; if
    *validation_set* is given, training can stop early once the
    validation error rate reaches args.validation_error_threshold.
    """
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): the body of this test (moving the criterion to the
    # GPU?) is not visible in this excerpt.
    if torch.cuda.is_available():

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):

        # NOTE(review): acc_loss is presumably reset to 0 at the top of
        # each epoch -- not visible here.
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): loss.data[0] is the pre-0.4 PyTorch idiom;
            # newer versions require loss.item().
            acc_loss = acc_loss + loss.data[0]

        # Average seconds per epoch so far, used for the ETA below.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        # Checkpoint: weights plus the number of completed epochs.
        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)

            # Early stopping once the validation error rate is low enough.
            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
325 ######################################################################
# Record every command-line setting in the log, for reproducibility.
for name, value in vars(args).items():
    log_string('argument {} {}'.format(name, value))
330 ######################################################################
def int_to_suffix(n):
    """Return *n* as a compact string for use in file names.

    Exact multiples of one million become 'xM', exact multiples of one
    thousand become 'xK', and anything else is plain decimal.
    """
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    # Fallback for values that are not exact multiples; the visible
    # code implicitly returned None here, which would break the
    # string concatenation building model_filename below.
    return str(n)
class vignette_logger():
    """Callable progress logger for vignette generation.

    Passed as 'logger' to the VignetteSet constructors; reports at most
    once every *delay_min* seconds, with an ETA extrapolated from the
    generation rate so far.
    """
    def __init__(self, delay_min = 60):
        self.start_t = time.time()   # when generation started
        self.last_t = self.start_t   # time of the last report
        self.delay_min = delay_min   # minimum seconds between reports

    def __call__(self, n, m):
        # n, m: presumably total samples and samples generated so far --
        # confirm against svrtset's logger protocol.
        # NOTE(review): 't' is not assigned in the visible code
        # (presumably t = time.time()), self.last_t does not appear to
        # be refreshed, and the format string below has two fields but
        # only one visible argument -- confirm against the complete
        # source.
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m   # seconds per generated sample
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
def save_examplar_vignettes(data_set, nb, name):
    """Save *nb* randomly chosen vignettes from *data_set* as one image.

    The selected vignettes are assembled into a (nb, 1, H, W) patchwork
    tensor and written to the file *name* via torchvision.
    """
    # nb distinct sample indices, drawn without replacement.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        b = n[k] // data_set.batch_size   # batch holding sample n[k]
        m = n[k] % data_set.batch_size    # index of the sample within it
        i, t = data_set.get_batch(b)

        # Allocate the patchwork lazily, once the vignette size is known.
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        # NOTE(review): i is a whole batch here; the selection of sample
        # m (e.g. i = i[m]) is not visible in this excerpt.
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
371 ######################################################################
# Data sets are generated and consumed in whole batches, so the sample
# counts must divide evenly.
# NOTE(review): an exit()/raise presumably follows the message -- not
# visible in this excerpt.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')

log_string('############### start ###############')

# Choose the in-memory vignette storage: losslessly compressed (smaller
# footprint) or raw tensors.
# NOTE(review): the 'else:' between the two assignments is not visible
# in this excerpt.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
386 ########################################
# Resolve the class to instantiate from its --model name attribute.
# NOTE(review): the assignment of model_class inside the loop (and its
# initialization to None before it) is not visible in this excerpt.
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]:
    if args.model == m.name:

if model_class is None:
    print('Unknown model ' + args.model)

# NOTE(review): this logs the loop variable m (the last class in the
# list) rather than the selected model_class -- confirm intended.
log_string('using model class ' + m.name)
397 ########################################
# Main loop: for each requested SVRT problem, build (or reload) a model,
# train it if needed, and measure its error rates.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint name encodes the problem number and the training set
    # size, e.g. 'deepnet_pb:4_ns:100K.state'.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    # NOTE(review): nb_parameters is presumably initialized to 0 just
    # above -- not visible in this excerpt.
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    # NOTE(review): this load is presumably wrapped in a try/except that
    # sets need_to_train = True when no checkpoint exists -- not visible
    # in this excerpt.
    model_state_dict, nb_epochs_done = torch.load(model_filename)
    model.load_state_dict(model_state_dict)
    log_string('loaded_model ' + model_filename)

    ##################################################

    # (Re)train when the checkpoint covers fewer epochs than requested.
    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        # NOTE(review): 't' is presumably set to time.time() just before
        # the train_set construction -- not visible in this excerpt.
        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only generated when early stopping is on.
        # NOTE(review): the 'else:' before 'validation_set = None' is
        # not visible in this excerpt.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_train_errors / train_set.nb_samples,
            train_set.nb_samples)

    ##################################################

    # Test freshly trained models always, loaded ones only on request.
    if need_to_train or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
489 ######################################################################