3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
30 from colorama import Fore, Back, Style
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
42 from torch.nn import functional as fn
44 from torchvision import datasets, transforms, utils
50 ######################################################################
def strtobool(val):
    """Convert a truth-value string to 1 (true) or 0 (false).

    Stand-in for ``distutils.util.strtobool``: distutils is deprecated
    (PEP 632) and no longer ships with Python 3.12+. Accepts the same
    case-insensitive spellings and raises the same ValueError otherwise.
    The name is kept so that argparse error messages ("invalid strtobool
    value") read as they did with the distutils function.
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError('invalid truth value %r' % (val,))


# Command-line options. String defaults of the boolean options ('True' /
# 'False') are passed through `type` by argparse, so they become 1 / 0.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = -1)

parser.add_argument('--compress_vignettes',
                    type = strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
100 ######################################################################
# Log file written by log_string(); opened in append mode, so successive
# runs using the same --log_file accumulate in one file.
log_file = open(args.log_file, 'a')

# Time of the previous log_string() call, None before the first one; the
# elapsed time log_string() prints is measured from it. log_string()
# declares it `global` and compares it with None, so it must exist at
# module level before the first call.
pred_log_t = None

# Reference time log_string() uses to decide when to print a wall-clock
# banner (more than 3600 s after this time).
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
108 # Log and prints the string, with a time stamp. Does not log the
# Purpose: write `s` with a time stamp and an elapsed-time field to log_file,
# and echo it in color on stdout followed by the optional `remark`.
# Only `s` is written to the file; `remark` appears on the console only.
# NOTE(review): this excerpt skips several original lines of this function
# (the embedded numbering jumps, e.g. 112 -> 116 -> 119 -> 123), so the
# branch structure below is incomplete as shown -- confirm against the full file.
111 def log_string(s, remark = ''):
# Module-level state: elapsed times are measured from pred_log_t (None
# before the first call); last_tag_t is the reference for the hourly banner.
112 global pred_log_t, last_tag_t
# NOTE(review): `t` is read below, but its assignment is not visible here;
# presumably the current time.time() -- TODO confirm.
116 if pred_log_t is None:
# NOTE(review): the body of this branch (the first-call value of `elapsed`)
# and the matching `else:` are not visible; the next line is presumably the
# else branch.
119 elapsed = '+{:.02f}s'.format(t - pred_log_t)
# Print a red wall-clock banner once more than 3600 s have passed since
# last_tag_t (whether last_tag_t is then updated is not visible here).
123 if t > last_tag_t + 3600:
125 print(Fore.RED + time.ctime() + Style.RESET_ALL)
# File line: ctime with spaces replaced by '_' (so the stamp has no spaces),
# then the elapsed time, then the message.
127 log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
# Console echo: blue time stamp, green elapsed time, message, cyan remark.
130 print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
132 ######################################################################
# Afroze's ShallowNet
#
#                        map size   nb. maps
#                      ----------------------
#    input                128x128      1
# -- conv(21x21 x 6)   -> 108x108      6
# -- max(2x2)          -> 54x54        6
# -- conv(19x19 x 16)  -> 36x36       16
# -- max(2x2)          -> 18x18       16
# -- conv(18x18 x 120) -> 1x1        120
# -- reshape           -> 120          1
# -- full(120x84)      -> 84           1
# -- full(84x2)        -> 2            1

class AfrozeShallowNet(nn.Module):
    """LeNet-like two-class classifier for single-channel vignettes.

    The unpadded convolutions fix the input size to 1x128x128: the 21x21
    convolution maps 128x128 to 108x108, and the final 18x18 convolution
    reduces the 18x18 maps to 1x1 (see the table above). forward() returns
    raw class scores of shape (batch, 2).
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Prefix of the parameter file name the main loop saves/loads.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # (batch, 120, 1, 1) -> (batch, 120): the "reshape" step of the table.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        # No softmax: training uses nn.CrossEntropyLoss, which takes raw scores.
        return self.fc2(x)
167 ######################################################################
171 class AfrozeDeepNet(nn.Module):
# "Alexnet-like deep model" (per the --deep_model help text): five
# convolutions followed by three fully-connected layers ending in 2 class
# scores.
# NOTE(review): this excerpt omits lines of the original class (e.g. line
# 172, presumably `def __init__(self):`, and most of forward()); the code
# below is incomplete as shown -- confirm against the full file.
173 super(AfrozeDeepNet, self).__init__()
# conv1 (kernel 7, stride 4, padding 3) reduces each spatial side to
# ceil(side / 4); conv2-conv5 are stride 1 with padding (kernel - 1) / 2,
# so they keep the spatial size.
174 self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
175 self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
176 self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
177 self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
178 self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
# 1536 = 96 maps x 4 x 4, which is what a 128x128 input gives after conv1
# and the three 2x2 poolings below -- assumes 128x128 input, TODO confirm.
179 self.fc1 = nn.Linear(1536, 256)
180 self.fc2 = nn.Linear(256, 256)
181 self.fc3 = nn.Linear(256, 2)
# Prefix of the parameter file name the main loop saves/loads.
182 self.name = 'deepnet'
184 def forward(self, x):
# NOTE(review): only the three 2x2 max-poolings of forward() are visible in
# this excerpt; the convolutions, activations, reshape to 1536 features,
# fully-connected layers and the return are not shown.
186 x = fn.max_pool2d(x, kernel_size=2)
190 x = fn.max_pool2d(x, kernel_size=2)
200 x = fn.max_pool2d(x, kernel_size=2)
215 ######################################################################
def nb_errors(model, data_set):
    """Return the number of samples of `data_set` that `model` misclassifies.

    Runs `model` on each batch `data_set.get_batch(b)` for b in
    [0, data_set.nb_batches) and counts the samples whose highest-scoring
    class (winner-take-all prediction) differs from the target label.

    Returns:
        A plain Python int, suitable for '{:d}' formatting.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        inputs, targets = data_set.get_batch(b)
        output = model.forward(Variable(inputs))
        wta_prediction = output.data.max(1)[1].view(-1)
        # Compare the whole batch at once instead of element by element;
        # int() turns the (possibly 0-dim tensor) sum into a Python int.
        ne += int((wta_prediction != targets).long().sum())
    return ne
230 ######################################################################
232 def train_model(model, train_set, validation_set):
# Train `model` with SGD (lr 1e-2) on the cross-entropy loss for --nb_epochs
# epochs over the batches of `train_set`, logging the accumulated loss and an
# ETA after each epoch. When `validation_set` is not None, its error rate is
# also logged each epoch and compared with --validation_error_threshold
# ("Early training termination criterion").
# NOTE(review): this excerpt omits several original lines (e.g. 237, 244,
# 250-252, 265, 269-272), presumably including the acc_loss reset and the
# zero_grad()/backward()/step() update; the body is incomplete as shown.
233 batch_size = args.batch_size
# NOTE(review): batch_size is not used in any visible line.
234 criterion = nn.CrossEntropyLoss()
# (the body of this branch is not shown in this excerpt)
236 if torch.cuda.is_available():
239 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
241 start_t = time.time()
243 for e in range(0, args.nb_epochs):
245 for b in range(0, train_set.nb_batches):
246 input, target = train_set.get_batch(b)
247 output = model.forward(Variable(input))
248 loss = criterion(output, Variable(target))
# NOTE(review): indexing the 0-dim loss with [0] raises on PyTorch >= 0.5;
# loss.item() (available since 0.4) is the replacement.
249 acc_loss = acc_loss + loss.data[0]
# Mean wall-clock duration of an epoch so far (e + 1 epochs completed).
253 dt = (time.time() - start_t) / (e + 1)
# NOTE(review): after epoch e, nb_epochs - (e + 1) epochs remain, so the ETA
# computed with (nb_epochs - e) below overshoots by one epoch.
255 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
256 ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
258 if validation_set is not None:
259 nb_validation_errors = nb_errors(model, validation_set)
261 log_string('validation_error {:.02f}% {:d} {:d}'.format(
262 100 * nb_validation_errors / validation_set.nb_samples,
263 nb_validation_errors,
264 validation_set.nb_samples)
# NOTE(review): the call above lacks its closing ')' in this excerpt (line
# 265 not shown).
# Early-termination test: compares the error *fraction* (not the percentage
# logged above) with the threshold. The action taken after the message below
# (presumably leaving the epoch loop) is not visible.
267 if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
268 log_string('below validation_error_threshold')
273 ######################################################################
# Record the value of every command-line argument in the log.
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
278 ######################################################################
def int_to_suffix(n):
    """Return a compact string for the integer `n` (used in model file names).

    Exact multiples of one million get an 'M' suffix and exact multiples of
    one thousand a 'K' suffix; any other value is written in full, so the
    result is always a string:

    >>> int_to_suffix(2000000), int_to_suffix(100000), int_to_suffix(1500)
    ('2M', '100K', '1500')
    """
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    if n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    # Fallback for values without an exact K/M form; callers concatenate the
    # result with other strings, so it must never be None.
    return str(n)
class vignette_logger():
    """Rate-limited progress callback for vignette generation.

    Instances are passed as ``logger`` to the vignette-set constructors and
    called as ``logger(n, m)``. Judging from the ETA formula, `n` is the total
    number of samples to generate and `m` the number generated so far. At most
    one progress line is logged per `delay_min` period. Despite its name,
    `delay_min` is in seconds, because it is compared with time.time()
    differences.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        # m > 0 keeps the per-sample time estimate free of a division by zero.
        if t > self.last_t + self.delay_min and m > 0:
            dt = (t - self.start_t) / m  # mean time per generated sample
            log_string('sample_generation {:d} / {:d}'.format(m, n),
                       ' [ETA ' + time.ctime(t + dt * (n - m)) + ']')
            # Restart the delay so that lines are spaced by delay_min seconds
            # instead of being emitted on every call after the first one.
            self.last_t = t
304 def save_examplar_vignettes(data_set, nb, name):
# Save `nb` randomly chosen vignettes of `data_set` as one image in the file
# `name` (the main loop passes 'examplar_<problem>.png').
# NOTE(review): original lines 306 and 311-313 are not shown in this excerpt;
# the body is incomplete as shown -- confirm against the full file.
# nb distinct sample indices: the first nb entries of a random permutation
# (narrow() fails if nb exceeds data_set.nb_samples).
305 n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
307 for k in range(0, nb):
# Sample n[k] lives in batch b, at position m inside that batch.
308 b = n[k] // data_set.batch_size
309 m = n[k] % data_set.batch_size
310 i, t = data_set.get_batch(b)
# NOTE(review): as shown, `i` is the whole batch and `m` is never used. The
# omitted lines 311-313 presumably select (and perhaps rescale) sample m, so
# that `i` is a single 3-D vignette here -- TODO confirm; otherwise the sizes
# used below are wrong.
# The (nb, 1, i.size(1), i.size(2)) FloatTensor output is allocated once,
# using the first vignette's size.
314 if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
315 patchwork[k].copy_(i)
# NOTE(review): the visible imports bring in `utils` from torchvision but not
# the `torchvision` name itself; confirm it is imported elsewhere.
317 torchvision.utils.save_image(patchwork, name)
319 ######################################################################
# Every sample set is processed in batches of args.batch_size, so each set
# that will actually be built must hold a whole number of batches. The
# validation set is only built when --validation_error_threshold > 0.
if (args.nb_train_samples % args.batch_size > 0
        or args.nb_test_samples % args.batch_size > 0
        or (args.validation_error_threshold > 0.0
            and args.nb_validation_samples % args.batch_size > 0)):
    print('The number of samples must be a multiple of the batch size.')
    # Fatal: stop here rather than fail later with a less explicit error.
    raise SystemExit(1)
log_string('############### start ###############')

# Pick the vignette-set implementation. The compressed one uses lossless
# compression to reduce the memory footprint (see --compress_vignettes).
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
# Main loop: for each problem number listed in --problems, build a model and
# try to load its saved parameters. When needed, train it, save it and log
# its training error. Then log its test error, either after training or when
# --test_loaded_models is set.
# NOTE(review): this excerpt omits many original lines (e.g. 337/340, which
# presumably choose between the two models, 349, 357 and 360-367 around the
# loading step, and 422+ at the end); the code is incomplete as shown.
334 for problem_number in map(int, args.problems.split(',')):
336 log_string('############### problem ' + str(problem_number) + ' ###############')
# Presumably AfrozeDeepNet when --deep_model is true and AfrozeShallowNet
# otherwise; the selecting if/else is not shown.
339 model = AfrozeDeepNet()
341 model = AfrozeShallowNet()
343 if torch.cuda.is_available(): model.cuda()
# Parameter file name, e.g. 'deepnet_pb:1_ns:100K.param' with the default
# --nb_train_samples. NOTE(review): ':' is not a valid file-name character
# on Windows.
345 model_filename = model.name + '_pb:' + \
346 str(problem_number) + '_ns:' + \
347 int_to_suffix(args.nb_train_samples) + '.param'
# Total number of scalar parameters (the initialization of nb_parameters is
# not shown).
350 for p in model.parameters(): nb_parameters += p.numel()
351 log_string('nb_parameters {:d}'.format(nb_parameters))
353 ##################################################
354 # Tries to load the model
# need_to_train starts False; presumably it is set to True when loading the
# parameter file fails (that branch is not shown).
356 need_to_train = False
358 model.load_state_dict(torch.load(model_filename))
359 log_string('loaded_model ' + model_filename)
363 ##################################################
# Training path (its guarding `if need_to_train:` is not shown).
368 log_string('training_model ' + model_filename)
372 train_set = VignetteSet(problem_number,
373 args.nb_train_samples, args.batch_size,
374 cuda = torch.cuda.is_available(),
375 logger = vignette_logger())
# Generation throughput. `t` is presumably a time.time() stamp taken before
# train_set was built (its assignment is not shown).
377 log_string('data_generation {:0.2f} samples / s'.format(
378 train_set.nb_samples / (time.time() - t))
# Optionally save random training vignettes as examplar_<problem>.png.
381 if args.nb_exemplar_vignettes > 0:
382 save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
383 'examplar_{:d}.png'.format(problem_number))
# A validation set is only generated when early termination is enabled
# (threshold > 0); otherwise train_model() receives None (the `else:` is not
# shown).
385 if args.validation_error_threshold > 0.0:
386 validation_set = VignetteSet(problem_number,
387 args.nb_validation_samples, args.batch_size,
388 cuda = torch.cuda.is_available(),
389 logger = vignette_logger())
391 validation_set = None
393 train_model(model, train_set, validation_set)
# Save under the same file name the loading step above reads, so later runs
# can skip training.
394 torch.save(model.state_dict(), model_filename)
395 log_string('saved_model ' + model_filename)
397 nb_train_errors = nb_errors(model, train_set)
# NOTE(review): four format fields, but argument lines 400 and 402 are not
# shown in this excerpt.
399 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
401 100 * nb_train_errors / train_set.nb_samples,
403 train_set.nb_samples)
406 ##################################################
# The test error is computed after training, or for a loaded model only when
# --test_loaded_models is set.
409 if need_to_train or args.test_loaded_models:
# NOTE(review): unlike train_set and validation_set, no vignette_logger is
# passed here, so test-set generation reports no progress -- confirm intended.
413 test_set = VignetteSet(problem_number,
414 args.nb_test_samples, args.batch_size,
415 cuda = torch.cuda.is_available())
417 nb_test_errors = nb_errors(model, test_set)
419 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
421 100 * nb_test_errors / test_set.nb_samples,
# NOTE(review): the excerpt ends inside this call; its remaining arguments and
# closing parentheses (lines 422+) are not shown.
426 ######################################################################