3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
30 from colorama import Fore, Back, Style
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
42 from torch.nn import functional as fn
44 from torchvision import datasets, transforms, utils
50 ######################################################################
def _strtobool(value):
    """Convert a truth-value string to 1 (true) or 0 (false).

    Local replacement for ``distutils.util.strtobool``: distutils was
    deprecated by PEP 632 and is no longer shipped from Python 3.12 on.
    The accepted spellings and the int return values are the same, so
    the boolean options behave, and are logged, exactly as before.

    Raises ValueError for any other string, which argparse reports as an
    invalid option value.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError('invalid truth value {!r}'.format(value))


parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000,
                    help = 'Number of training samples generated per problem')

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000,
                    help = 'Number of test samples generated per problem')

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000,
                    help = 'Number of validation samples, only used when validation_error_threshold > 0')

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'Maximum number of training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'default.log',
                    help = 'File the log is appended to')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32,
                    help = 'Number of example vignettes saved as an image per problem (0 to disable)')

parser.add_argument('--compress_vignettes',
                    type = _strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = _strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
100 ######################################################################
log_file = open(args.log_file, 'a')

# Time of the previous log_string call (None before the first one), and
# time at which the full date was last printed on the terminal.
pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string with a time stamp and the time elapsed
# since the previous call. The remark is only printed on the terminal,
# it is not written to the log file.

def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    t = time.time()
    # Format the date from t itself so that the printed/logged date and
    # the elapsed time refer to the same instant.
    now = time.ctime(t)

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Re-print the full date on the terminal at most once per hour
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + now + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', now) + ' ' + elapsed + ' ' + s + '\n')
    # Flush so the log is complete even if the run is interrupted
    log_file.flush()

    print(Fore.BLUE + now + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
132 ######################################################################
# Afroze's ShallowNet
#
#                        map size   nb. maps
#                      ----------------------
#    input                128x128      1
# -- conv(21x21 x 6)   -> 108x108      6
# -- max(2x2)          -> 54x54        6
# -- conv(19x19 x 16)  -> 36x36        16
# -- max(2x2)          -> 18x18        16
# -- conv(18x18 x 120) -> 1x1          120
# -- reshape           -> 120          1
# -- full(120x84)      -> 84           1
# -- full(84x2)        -> 2            1

class AfrozeShallowNet(nn.Module):
    """Three convolutions followed by two fully connected layers.

    The layer sizes assume 128x128 single-channel inputs: the last
    convolution then produces a 1x1 map with 120 channels. The output
    has two scores per sample.
    """
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        # Attribute names are the keys of saved state dicts: keep them.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        # Two conv -> 2x2 max-pool -> ReLU stages
        for conv in (self.conv1, self.conv2):
            x = fn.relu(fn.max_pool2d(conv(x), kernel_size=2))
        features = fn.relu(self.conv3(x)).view(-1, 120)
        hidden = fn.relu(self.fc1(features))
        return self.fc2(hidden)
168 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Five convolutions and three fully connected layers.

    With a 128x128 single-channel input, the stride-4 first convolution
    and the three 2x2 max-poolings reduce the maps to 4x4, so the last
    convolution yields 96 x 4 x 4 = 1536 features for fc1. The output
    has two scores per sample.
    """
    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        # Attribute names are the keys of saved state dicts: keep them.
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # Each stage: convolution, optional 2x2 max-pooling, ReLU
        stages = (
            (self.conv1, True),
            (self.conv2, True),
            (self.conv3, False),
            (self.conv4, False),
            (self.conv5, True),
        )
        for conv, pooled in stages:
            x = conv(x)
            if pooled:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        x = x.view(-1, 1536)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
218 ######################################################################
class DeepNet2(nn.Module):
    """Five convolutions with 128 maps and three fully connected layers.

    With a 128x128 single-channel input, the stride-4 first convolution
    and the three 2x2 max-poolings reduce the maps to 4x4, so the last
    convolution yields 128 x 4 x 4 = 2048 features for fc1. The hidden
    fully connected layers have 512 units; the output has two scores per
    sample.
    """
    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        # Attribute names are the keys of saved state dicts: keep them.
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 512)
        # Must take fc2's 512 outputs (was 256, which made every forward
        # pass fail with a size mismatch).
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 2048)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
265 ######################################################################
def nb_errors(model, data_set):
    """Count the samples of data_set that model misclassifies.

    The prediction for a sample is the index of its largest output
    (winner-take-all). All data_set.nb_batches batches returned by
    data_set.get_batch are evaluated, and the number of samples whose
    prediction differs from the target is returned as an int.
    """
    errors = 0

    for b in range(0, data_set.nb_batches):
        inputs, target = data_set.get_batch(b)
        output = model.forward(Variable(inputs))
        wta_prediction = output.data.max(1)[1].view(-1)
        # Compare the whole batch at once, over the samples actually
        # returned instead of assuming exactly data_set.batch_size.
        errors += int((wta_prediction != target.view(-1)).long().sum())

    return errors
280 ######################################################################
def train_model(model, train_set, validation_set):
    """Train model in place with SGD on the cross-entropy loss.

    Runs up to args.nb_epochs epochs over the train_set.nb_batches
    mini-batches of train_set. After each epoch the summed training loss
    and an ETA are logged. When validation_set is not None its error rate
    is logged too, and training stops early once that rate is at or below
    args.validation_error_threshold.

    Returns the trained model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss.data is a one-element tensor on old PyTorch and a 0-dim
            # tensor from 0.4 on, where loss.data[0] raises; view(-1)[0]
            # works on both.
            acc_loss = acc_loss + float(loss.data.view(-1)[0])
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Average wall-clock time per epoch so far
        dt = (time.time() - start_t) / (e + 1)

        # e + 1 epochs are done, so args.nb_epochs - e - 1 remain
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e - 1)) + ']')

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)
            # Explicit float so the ratio is never truncated by integer
            # division (Python 2)
            validation_error = float(nb_validation_errors) / validation_set.nb_samples

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * validation_error,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if validation_error <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
323 ######################################################################
# Record the value of every command-line argument in the log
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
328 ######################################################################
def int_to_suffix(n):
    """Return a compact decimal string for the integer n.

    Exact multiples of a million are written with an 'M' suffix, other
    exact multiples of a thousand with a 'K' suffix, and every other
    value as plain str(n), so a string is always returned (it is
    concatenated into the model file name).

    >>> int_to_suffix(2000000), int_to_suffix(100000), int_to_suffix(1500)
    ('2M', '100K', '1500')
    """
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    elif n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    else:
        return str(n)
class vignette_logger():
    """Progress callback passed to the vignette sets as `logger`.

    Calling it as logger(n, m) logs 'sample_generation m / n' with an ETA,
    at most once every delay_min seconds. The ETA formula treats m as the
    number of samples produced so far and n as the total.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        # Minimum delay, in seconds, between two log lines
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        # Nothing to extrapolate from before the first sample (and the
        # per-sample time below would divide by zero).
        if m <= 0 or t <= self.last_t + self.delay_min:
            return
        dt = (t - self.start_t) / m
        log_string('sample_generation {:d} / {:d}'.format(m, n),
                   ' [ETA ' + time.ctime(t + dt * (n - m)) + ']')
        self.last_t = t
def save_examplar_vignettes(data_set, nb, name):
    """Save nb randomly chosen vignettes of data_set as one image grid.

    data_set must provide nb_samples, batch_size and get_batch(b). Each
    vignette is rescaled to [0, 1] and the grid is written to the file
    name with torchvision's save_image. Nothing is saved when nb <= 0.
    """
    # Cannot draw more distinct samples than the set contains
    nb = min(nb, data_set.nb_samples)
    if nb <= 0:
        return

    indices = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        # Plain Python ints for batch lookup and indexing
        b, m = divmod(int(indices[k]), data_set.batch_size)
        inputs, _ = data_set.get_batch(b)
        # Assumes inputs is (batch, 1, H, W) — TODO confirm; the patchwork
        # shape below relies on inputs[m] being (1, H, W).
        i = inputs[m].float().cpu()
        # Out-of-place rescaling, so the data set's own tensor is never
        # modified; a uniform vignette (zero range) is left at 0.
        i = i - i.min()
        span = i.max()
        if span > 0:
            i = i / span
        if k == 0:
            patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    utils.save_image(patchwork, name)
369 ######################################################################
# Each data set is built from batches of args.batch_size samples, so
# every sample count actually used must be a multiple of it. The
# validation set is only generated when early stopping is enabled.
sample_counts = [ args.nb_train_samples, args.nb_test_samples ]
if args.validation_error_threshold > 0.0:
    sample_counts.append(args.nb_validation_samples)

if any(c % args.batch_size > 0 for c in sample_counts):
    print('The number of samples must be a multiple of the batch size.')
    raise SystemExit(1)

log_string('############### start ###############')

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
384 ########################################
# Select the model class whose `name` attribute matches --model
model_classes_by_name = {
    c.name: c for c in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]
}

model_class = model_classes_by_name.get(args.model)

if model_class is None:
    print('Unknown model ' + args.model)
    raise SystemExit(1)

# Log the selected class itself rather than a leftover loop variable
log_string('using model class ' + model_class.name)
395 ########################################
# Empty entries (e.g. a trailing comma in --problems) are skipped
for problem_number in [ int(p) for p in args.problems.split(',') if p.strip() ]:

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except (IOError, OSError):
        # No saved model for this problem yet
        need_to_train = True
    except Exception as err:
        # The file exists but cannot be loaded into this model. Retrain
        # from a fresh instance, since load_state_dict may have copied
        # some tensors before failing.
        log_string('cannot_load_model ' + model_filename + ' ' + str(err))
        need_to_train = True
        model = model_class()
        if torch.cuda.is_available(): model.cuda()

    ##################################################
    # Train if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, train_set, validation_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if need_to_train or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
486 ######################################################################