3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
30 from colorama import Fore, Back, Style
37 from torch import optim
38 from torch import FloatTensor as Tensor
39 from torch.autograd import Variable
41 from torch.nn import functional as fn
42 from torchvision import datasets, transforms, utils
48 ######################################################################
# Command-line interface. All sample counts must be multiples of
# --batch_size (checked later, before training starts).
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# Number of vignettes to dump as a PNG for inspection; <= 0 disables it.
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = -1)

# NOTE: distutils.util.strtobool accepts 'True'/'False' (and variants) on
# the command line and maps them to 1/0.
parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
98 ######################################################################
# Open the log file in append mode so successive runs accumulate in it.
log_file = open(args.log_file, 'a')

# Time of the last hourly red time-stamp tag printed by log_string.
# NOTE(review): log_string also reads a global pred_log_t which is not
# initialized in the visible lines — confirm it is set to None elsewhere.
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the remark.
def log_string(s, remark = ''):
    """Write s to the log file and echo it in color on stdout.

    The log line is prefixed with a time stamp and the elapsed time since
    the previous call ('NA' on the first call). The remark is printed on
    the console only, never written to the log file. Once an hour a plain
    red time-stamp tag is also printed on the console.
    """
    global pred_log_t, last_tag_t

    t = time.time()

    # Visible source read t and pred_log_t without ever assigning them in
    # this branch — reconstructed: 'NA' on the first call, delta afterward.
    if pred_log_t is None:
        elapsed = 'NA'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Hourly console tag so long runs remain easy to navigate.
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # Spaces in the time stamp are replaced so each log field stays
    # splittable on whitespace.
    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
130 ######################################################################
132 # Afroze's ShallowNet
135 # ----------------------
137 # -- conv(21x21 x 6) -> 108x108 6
138 # -- max(2x2) -> 54x54 6
139 # -- conv(19x19 x 16) -> 36x36 16
140 # -- max(2x2) -> 18x18 16
141 # -- conv(18x18 x 120) -> 1x1 120
142 # -- reshape -> 120 1
143 # -- full(120x84) -> 84 1
144 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's LeNet-style shallow net for 1x128x128 SVRT vignettes.

    Per the architecture comment above:
      conv(21x21 x 6)  -> 108x108x6, max(2x2) -> 54x54x6,
      conv(19x19 x 16) -> 36x36x16,  max(2x2) -> 18x18x16,
      conv(18x18 x 120)-> 1x1x120, then fc 120->84 and fc 84->2 logits.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used to build the saved-parameter filename.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten the 1x1x120 maps before the classifier head (the visible
        # source never flattened, never applied fc2 and never returned).
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
165 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Afroze's AlexNet-like deep net for 1x128x128 SVRT vignettes.

    conv1 (7x7, stride 4) + pool halves twice: 128 -> 32 -> 16; conv2 + pool
    -> 8; conv3/conv4/conv5 keep 8x8; a final pool gives 96x4x4 = 1536
    features, matching fc1's input size; the head emits 2 class logits.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used to build the saved-parameter filename.
        self.name = 'deepnet'

    def forward(self, x):
        # The visible source had only the three pooling lines; the conv
        # applications, flatten and fully-connected head are restored here
        # (1536 = 96 channels * 4 * 4 pins the reconstruction).
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
213 ######################################################################
def nb_errors(model, data_set):
    """Return the number of misclassified samples of model over data_set.

    data_set must expose nb_batches, batch_size and get_batch(b) returning
    an (input, target) pair. Prediction is the winner-take-all argmax over
    the two output logits.
    """
    # The visible source never initialized, incremented or returned the
    # error count — restored here.
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Index of the max logit per sample, flattened to a 1D vector.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
228 ######################################################################
def train_model(model, train_set, validation_set):
    """Train model on train_set with SGD + cross-entropy and return it.

    After each epoch the accumulated loss and an ETA are logged. If
    validation_set is not None, the validation error is computed each epoch
    and training stops early once it drops to
    args.validation_error_threshold or below.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            # The visible source accumulated the loss but never
            # backpropagated — without these three lines no weight is
            # ever updated.
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Mean wall-clock time per epoch so far, used for the ETA.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                # Early termination criterion met.
                break

    return model
271 ######################################################################
# Record every command-line argument value in the log, so each run is
# fully reproducible from its log file alone.
for name, value in vars(args).items():
    log_string('argument ' + str(name) + ' ' + str(value))
276 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string for filenames.

    Exact multiples of a million become '<k>M', exact multiples of a
    thousand become '<k>K', anything else is the plain decimal. The
    visible source had no fallback and returned None for e.g. n=1500.
    """
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
class vignette_logger():
    """Throttled progress logger passed to the vignette generators.

    Called as logger(n, m) with n the total number of samples to generate
    and m the number generated so far; logs progress and an ETA at most
    once every delay_min seconds.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        # Minimum delay between two log lines, in seconds despite the name.
        self.delay_min = delay_min

    def __call__(self, n, m):
        # The visible source read t without assigning it, passed a single
        # argument to a two-placeholder format, and never refreshed
        # last_t — all three restored here.
        t = time.time()
        if t > self.last_t + self.delay_min:
            # Mean time per generated sample so far.
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
            self.last_t = t
def save_examplar_vignettes(data_set, nb, name):
    """Save nb randomly chosen vignettes from data_set as one image file.

    Picks nb distinct sample indices, normalizes each vignette to [0, 1]
    for display, and writes the patchwork with torchvision's save_image.
    """
    # nb distinct sample indices over the whole set.
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        # Map the flat sample index to (batch, position-in-batch).
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)
        # The visible source computed m but copied the whole batch; select
        # the single vignette m before normalizing.
        i = i[m].float()
        i = i - i.min()
        mx = i.max()
        if mx > 0: i = i / mx
        if k == 0: patchwork = Tensor(nb, 1, i.size(0), i.size(1))
        patchwork[k].copy_(i)

    # The file only imports 'utils' from torchvision (not the package
    # name), so call save_image through it.
    utils.save_image(patchwork, name)
317 ######################################################################
# NOTE(review): this top-level driver is truncated in the visible source —
# several lines (else branches, a try/except around model loading, the exit
# after the batch-size check, a counter initialization, a timer, and some
# format arguments) are missing. Each gap is flagged below; the code itself
# is left exactly as found.

# Both sample counts must be divisible by the batch size: the data sets are
# processed strictly in full batches.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # NOTE(review): presumably followed by a raise/sys.exit — not visible here.

log_string('############### start ###############')

# Select the vignette container class; compressed sets trade CPU for memory.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
    # NOTE(review): the next two lines are the body of a missing 'else:'.
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet

# One model per SVRT problem number listed in --problems.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    # NOTE(review): the model choice is presumably guarded by
    # 'if args.deep_model: ... else: ...' — the guard lines are not visible.
    model = AfrozeDeepNet()
    model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    # e.g. 'deepnet_pb:4_ns:100K.param'
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    # NOTE(review): the 'nb_parameters = 0' initialization is not visible.
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    # NOTE(review): the load is presumably wrapped in a try/except that sets
    # need_to_train = True on failure, and the training section below is
    # presumably under 'if need_to_train:' — those lines are not visible.
    model.load_state_dict(torch.load(model_filename))
    log_string('loaded_model ' + model_filename)

    ##################################################

    log_string('training_model ' + model_filename)

    # NOTE(review): 't = time.time()' before the generation is not visible,
    # although 't' is read below to compute the generation throughput.
    train_set = VignetteSet(problem_number,
                            args.nb_train_samples, args.batch_size,
                            cuda = torch.cuda.is_available(),
                            logger = vignette_logger())

    log_string('data_generation {:0.2f} samples / s'.format(
        train_set.nb_samples / (time.time() - t))

    # Optionally dump a PNG with a few training vignettes for inspection.
    if args.nb_exemplar_vignettes > 0:
        save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                'examplar_{:d}.png'.format(problem_number))

    # A validation set is generated only when early stopping is requested.
    if args.validation_error_threshold > 0.0:
        validation_set = VignetteSet(problem_number,
                                     args.nb_validation_samples, args.batch_size,
                                     cuda = torch.cuda.is_available(),
                                     logger = vignette_logger())
    # NOTE(review): body of a missing 'else:'.
    validation_set = None

    train_model(model, train_set, validation_set)
    torch.save(model.state_dict(), model_filename)
    log_string('saved_model ' + model_filename)

    nb_train_errors = nb_errors(model, train_set)

    # NOTE(review): some format arguments (problem number, error count) are
    # missing from the visible lines of this call, as is its closing paren.
    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        100 * nb_train_errors / train_set.nb_samples,
        train_set.nb_samples)

    ##################################################

    # Test freshly trained models, and loaded ones only on request.
    if need_to_train or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        # NOTE(review): this call is cut off mid-argument in the visible source.
        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
424 ######################################################################