3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from colorama import Fore, Back, Style
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
46 ######################################################################
# Command-line interface. ArgumentDefaultsHelpFormatter makes --help show
# the default value of every option.
48 parser = argparse.ArgumentParser(
49 description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50 formatter_class = argparse.ArgumentDefaultsHelpFormatter
# Number of generated training samples per problem; must be a multiple of
# --batch_size (checked before the per-problem loop). Also encoded in the
# saved-model filename through int_to_suffix().
53 parser.add_argument('--nb_train_samples',
54 type = int, default = 100000)
# Number of generated test samples per problem; must be a multiple of
# --batch_size.
56 parser.add_argument('--nb_test_samples',
57 type = int, default = 10000)
# Size of the validation set. It is only generated when
# --validation_error_threshold is > 0.0.
59 parser.add_argument('--nb_validation_samples',
60 type = int, default = 10000)
# Compared after each epoch (in train_model) against the validation error
# RATE, i.e. a fraction in [0, 1], not a percentage. With the default 0.0
# no validation set is built at all.
62 parser.add_argument('--validation_error_threshold',
63 type = float, default = 0.0,
64 help = 'Early training termination criterion')
66 parser.add_argument('--nb_epochs',
67 type = int, default = 50)
# Minibatch size used for training and for computing error counts.
69 parser.add_argument('--batch_size',
70 type = int, default = 100)
# Opened in append mode, so successive runs accumulate in the same file.
72 parser.add_argument('--log_file',
73 type = str, default = 'default.log')
# Boolean options below are given as strings (defaults 'True'/'False') and
# converted with distutils.util.strtobool.
# NOTE(review): distutils is deprecated (PEP 632) and was removed in
# Python 3.12; these converters need a replacement on newer interpreters.
# A true value selects svrtset.CompressedVignetteSet, a false one
# svrtset.VignetteSet.
75 parser.add_argument('--compress_vignettes',
76 type = distutils.util.strtobool, default = 'True',
77 help = 'Use lossless compression to reduce the memory footprint')
# Selects between AfrozeDeepNet and AfrozeShallowNet.
79 parser.add_argument('--deep_model',
80 type = distutils.util.strtobool, default = 'True',
81 help = 'Use Afroze\'s Alexnet-like deep model')
# Freshly trained models are always tested; models loaded from disk are
# tested only when this is true.
83 parser.add_argument('--test_loaded_models',
84 type = distutils.util.strtobool, default = 'False',
85 help = 'Should we compute the test errors of loaded models')
# Comma-separated list of SVRT problem numbers, converted with int() in the
# per-problem loop.
87 parser.add_argument('--problems',
88 type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
89 help = 'What problems to process')
91 args = parser.parse_args()
93 ######################################################################
# Log file written by log_string(). Append mode keeps the output of previous
# runs.
95 log_file = open(args.log_file, 'a')
98 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
100 # Logs and prints the string s, with a time stamp and the time elapsed
# since the previous log line. The optional remark is only printed to the
# console (in cyan); it is not written to the log file.
102 def log_string(s, remark = ''):
# pred_log_t is presumably the time of the previous call (None on the first
# call); elapsed is then formatted as '+<seconds>s'.
107 if pred_log_t is None:
110 elapsed = '+{:.02f}s'.format(t - pred_log_t)
# File line: "[<ctime>] <elapsed> <s>".
114 log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
# Console line: same content colourised with colorama, followed by remark.
# NOTE(review): time.ctime() is evaluated separately for the file and the
# console, so the two time stamps can differ by one second.
117 print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
119 ######################################################################
121 # Afroze's ShallowNet
124 # ----------------------
# Each row: layer -> output map size, number of maps.
126 # -- conv(21x21 x 6) -> 108x108 6
127 # -- max(2x2) -> 54x54 6
128 # -- conv(19x19 x 16) -> 36x36 16
129 # -- max(2x2) -> 18x18 16
130 # -- conv(18x18 x 120) -> 1x1 120
131 # -- reshape -> 120 1
132 # -- full(120x84) -> 84 1
133 # -- full(84x2) -> 2 1
# The table implies a single-channel 128x128 input (128 - 21 + 1 = 108).
# The two outputs of fc2 are class scores fed to nn.CrossEntropyLoss in
# train_model.
135 class AfrozeShallowNet(nn.Module):
137 super(AfrozeShallowNet, self).__init__()
138 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
139 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
# 18x18 kernel on an 18x18 map: each of the 120 outputs is a 1x1 map.
140 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
141 self.fc1 = nn.Linear(120, 84)
142 self.fc2 = nn.Linear(84, 2)
# Used as the prefix of the saved-parameter filename.
143 self.name = 'shallownet'
145 def forward(self, x):
# Each convolution is followed by 2x2 max-pooling, then ReLU.
146 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
147 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
148 x = fn.relu(self.conv3(x))
150 x = fn.relu(self.fc1(x))
154 ######################################################################
# Afroze's Alexnet-like deep model (see the --deep_model help text).
# fc1 expects 1536 = 96 maps x 4 x 4 features. That matches a 128x128
# single-channel input after the stride-4 conv1 (-> 32x32) and the three
# 2x2 max-poolings in forward (-> 4x4). Assumes 128x128 input; TODO confirm.
158 class AfrozeDeepNet(nn.Module):
160 super(AfrozeDeepNet, self).__init__()
# conv2..conv5 use padding = kernel_size // 2 with stride 1, so they keep
# the spatial size; only conv1 (stride 4) and the poolings shrink the maps.
161 self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
162 self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
163 self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
164 self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
165 self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
166 self.fc1 = nn.Linear(1536, 256)
167 self.fc2 = nn.Linear(256, 256)
# Two class scores, consumed by nn.CrossEntropyLoss in train_model.
168 self.fc3 = nn.Linear(256, 2)
# Used as the prefix of the saved-parameter filename.
169 self.name = 'deepnet'
171 def forward(self, x):
173 x = fn.max_pool2d(x, kernel_size=2)
177 x = fn.max_pool2d(x, kernel_size=2)
187 x = fn.max_pool2d(x, kernel_size=2)
202 ######################################################################
# Counts the samples of data_set whose predicted class (arg-max over the
# model outputs) differs from the target. Callers format the result with
# {:d} and divide it by data_set.nb_samples to obtain an error rate.
204 def nb_errors(model, data_set):
206 for b in range(0, data_set.nb_batches):
207 input, target = data_set.get_batch(b)
208 output = model.forward(Variable(input))
# Winner-take-all prediction: index of the largest output for each sample.
209 wta_prediction = output.data.max(1)[1].view(-1)
# NOTE(review): assumes every batch holds exactly data_set.batch_size
# samples. The main script only enforces this for the train and test sets.
# A per-element Python loop could also be replaced by a tensor comparison
# such as (wta_prediction != target).sum() for speed.
211 for i in range(0, data_set.batch_size):
212 if wta_prediction[i] != target[i]:
217 ######################################################################
# Trains model with SGD (lr = 1e-2) on the cross-entropy loss for
# args.nb_epochs epochs over the minibatches of train_set.
# Each epoch it logs:
#   - the summed training loss, with an ETA;
#   - if validation_set is not None, the validation error and whether it is
#     below --validation_error_threshold.
219 def train_model(model, train_set, validation_set):
220 batch_size = args.batch_size
221 criterion = nn.CrossEntropyLoss()
223 if torch.cuda.is_available():
226 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
# Reference time for the per-epoch ETA estimate.
228 start_t = time.time()
230 for e in range(0, args.nb_epochs):
232 for b in range(0, train_set.nb_batches):
233 input, target = train_set.get_batch(b)
234 output = model.forward(Variable(input))
235 loss = criterion(output, Variable(target))
# NOTE(review): loss is a scalar; indexing it with .data[0] fails on
# PyTorch >= 0.4 (0-dim tensors cannot be indexed). loss.item() is the
# modern equivalent.
236 acc_loss = acc_loss + loss.data[0]
# Average wall-clock duration of one epoch so far.
240 dt = (time.time() - start_t) / (e + 1)
# NOTE(review): after epoch e + 1, only args.nb_epochs - e - 1 epochs
# remain, so this ETA appears to be one epoch (dt) too late.
242 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
243 ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
245 if validation_set is not None:
246 nb_validation_errors = nb_errors(model, validation_set)
248 log_string('validation_error {:.02f}% {:d} {:d}'.format(
249 100 * nb_validation_errors / validation_set.nb_samples,
250 nb_validation_errors,
251 validation_set.nb_samples)
# The threshold is compared with the error rate as a fraction, not a
# percentage.
254 if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
255 log_string('below validation_error_threshold')
260 ######################################################################
# Record the value of every command-line argument in the log.
262 for arg in vars(args):
263 log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
265 ######################################################################
# Short form of a sample count, used in the saved-model filename.
# Exact multiples of 1,000,000 become '<n>M' and exact multiples of 1,000
# become '<n>K', e.g. 100000 -> '100K' and 2000000 -> '2M'.
# NOTE(review): other values must still produce a str (e.g. str(n)),
# because the caller concatenates the result with '+'. Confirm the
# fallback branch.
267 def int_to_suffix(n):
268 if n >= 1000000 and n%1000000 == 0:
269 return str(n//1000000) + 'M'
270 elif n >= 1000 and n%1000 == 0:
271 return str(n//1000) + 'K'
# Progress callback passed as `logger=` to VignetteSet during sample
# generation. It logs a progress line with an ETA when more than
# delay_min seconds have passed since last_t. delay_min is added to
# time.time() values, so it is in seconds.
275 class vignette_logger():
276 def __init__(self, delay_min = 60):
277 self.start_t = time.time()
278 self.last_t = self.start_t
279 self.delay_min = delay_min
# Presumably called with n = total number of samples to generate and
# m = number generated so far. The ETA assumes n - m samples remain at
# the average rate (t - start_t) / m. TODO confirm against svrtset.
# NOTE(review): m == 0 would raise ZeroDivisionError once the delay has
# elapsed.
281 def __call__(self, n, m):
283 if t > self.last_t + self.delay_min:
284 dt = (t - self.start_t) / m
285 log_string('sample_generation {:d} / {:d}'.format(
287 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
291 ######################################################################
# Data sets are split into full batches of args.batch_size, and nb_errors
# iterates over exactly batch_size samples per batch.
# NOTE(review): args.nb_validation_samples is not checked here, although
# the validation set is built with the same batch_size and scored by
# nb_errors.
293 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
294 print('The number of samples must be a multiple of the batch size.')
297 log_string('############### start ###############')
# Choose the data-set class once for all problems. Per the
# --compress_vignettes help text, the compressed variant uses lossless
# compression to reduce the memory footprint.
299 if args.compress_vignettes:
300 log_string('using_compressed_vignettes')
301 VignetteSet = svrtset.CompressedVignetteSet
303 log_string('using_uncompressed_vignettes')
304 VignetteSet = svrtset.VignetteSet
# For every requested SVRT problem:
#   - build a model;
#   - try to load its saved parameters;
#   - otherwise generate training (and optionally validation) data, train
#     the model, save it and log its training error;
#   - finally, optionally measure the test error on a freshly generated
#     test set.
306 for problem_number in map(int, args.problems.split(',')):
308 log_string('############### problem ' + str(problem_number) + ' ###############')
# Architecture chosen by --deep_model.
311 model = AfrozeDeepNet()
313 model = AfrozeShallowNet()
315 if torch.cuda.is_available(): model.cuda()
# Parameter file name, e.g. 'deepnet_pb:5_ns:100K.param'.
# NOTE(review): ':' is not a valid filename character on Windows.
317 model_filename = model.name + '_pb:' + \
318 str(problem_number) + '_ns:' + \
319 int_to_suffix(args.nb_train_samples) + '.param'
# Total number of scalar parameters in the model.
322 for p in model.parameters(): nb_parameters += p.numel()
323 log_string('nb_parameters {:d}'.format(nb_parameters))
325 ##################################################
326 # Tries to load the model
# need_to_train starts False (presumably set to True when loading fails);
# it later decides whether the test error is computed.
328 need_to_train = False
# NOTE(review): torch.load unpickles the file, so only load trusted
# parameter files. Newer PyTorch versions offer weights_only=True.
330 model.load_state_dict(torch.load(model_filename))
331 log_string('loaded_model ' + model_filename)
335 ##################################################
340 log_string('training_model ' + model_filename)
# Generate the training set, logging generation progress periodically.
344 train_set = VignetteSet(problem_number,
345 args.nb_train_samples, args.batch_size,
346 cuda = torch.cuda.is_available(),
347 logger = vignette_logger())
349 log_string('data_generation {:0.2f} samples / s'.format(
350 train_set.nb_samples / (time.time() - t))
# A validation set is only generated when early termination is enabled.
353 if args.validation_error_threshold > 0.0:
354 validation_set = VignetteSet(problem_number,
355 args.nb_validation_samples, args.batch_size,
356 cuda = torch.cuda.is_available(),
357 logger = vignette_logger())
359 validation_set = None
361 train_model(model, train_set, validation_set)
362 torch.save(model.state_dict(), model_filename)
363 log_string('saved_model ' + model_filename)
365 nb_train_errors = nb_errors(model, train_set)
367 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
369 100 * nb_train_errors / train_set.nb_samples,
371 train_set.nb_samples)
374 ##################################################
# Test freshly trained models always, and loaded ones only with
# --test_loaded_models.
377 if need_to_train or args.test_loaded_models:
# A new test set is generated for each problem (no progress logger).
381 test_set = VignetteSet(problem_number,
382 args.nb_test_samples, args.batch_size,
383 cuda = torch.cuda.is_available())
385 nb_test_errors = nb_errors(model, test_set)
387 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
389 100 * nb_test_errors / test_set.nb_samples,
394 ######################################################################