# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.

import time
import argparse
import distutils.util
import re

from colorama import Fore, Back, Style

# PyTorch

import torch

from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

# SVRT

import svrtset

######################################################################

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s AlexNet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
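
# A typical invocation might look like the following (values are
# illustrative, and the script name assumes this file is saved as
# cnn-svrt.py):
#
#   python cnn-svrt.py --nb_train_samples 10000 --nb_epochs 10 \
#                      --problems 1,5,20 --log_file pb_1_5_20.log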

######################################################################

log_file = open(args.log_file, 'a')
pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string, with a time stamp. Does not log the
# remark.

def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Print a red date line at most once an hour, as a visual marker
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
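
# For example, log_string('train_loss 1 0.171807') appends a line such as
#
#   Sat_Jun__3_10:14:02_2017 +2.45s train_loss 1 0.171807
#
# to the log file (time stamp and elapsed time are illustrative).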

######################################################################

# Afroze's ShallowNet
#
#                        map size   nb. maps
#                      ----------------------
#     input              128x128       1
#  -- conv(21x21 x 6)    108x108       6
#  -- max(2x2)            54x54        6
#  -- conv(19x19 x 16)    36x36       16
#  -- max(2x2)            18x18       16
#  -- conv(18x18 x 120)    1x1       120
#  -- reshape              120         1
#  -- full(120x84)          84         1
#  -- full(84x2)             2         1

class AfrozeShallowNet(nn.Module):
    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
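
# Quick shape sanity check (a sketch; SVRT vignettes are 1x128x128):
#
#   x = Variable(Tensor(1, 1, 128, 128).normal_())
#   AfrozeShallowNet()(x).size()  # torch.Size([1, 2])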

######################################################################

# Afroze's DeepNet

class AfrozeDeepNet(nn.Module):
    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.name = 'deepnet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 1536)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)
        return x
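
# Shape walkthrough for a 1x128x128 input, which is where fc1's 1536
# input units come from:
#
#   conv1 (7x7, stride 4, pad 3) -> 32x32x32,  max(2x2) -> 32x16x16
#   conv2 (5x5, pad 2)           -> 96x16x16,  max(2x2) -> 96x8x8
#   conv3, conv4 (3x3, pad 1)    -> 128x8x8
#   conv5 (3x3, pad 1)           -> 96x8x8,    max(2x2) -> 96x4x4
#   view                         -> 96*4*4 = 1536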

######################################################################

def nb_errors(model, data_set):
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
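
# The prediction is winner-take-all over the two class scores, so
# nb_errors(model, test_set) / test_set.nb_samples is the error rate.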

######################################################################

def train_model(model, train_set, validation_set):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
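
# Note: with a validation set, training stops early as soon as the
# validation error drops to args.validation_error_threshold; the ETA in
# the log is a linear extrapolation of the per-epoch time so far.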

######################################################################

for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

######################################################################

def int_to_suffix(n):
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
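
# For example int_to_suffix(2000000) == '2M', int_to_suffix(100000) == '100K',
# and int_to_suffix(1234) == '1234'.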

class vignette_logger():
    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
            self.last_t = t
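
# Judging from the ETA computation, n is the total number of vignettes
# to generate and m the number generated so far; progress is logged at
# most once every delay_min seconds (60 by default).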

######################################################################

if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    raise ValueError('The number of samples must be a multiple of the batch size.')

log_string('############### start ###############')

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet

for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except:
        need_to_train = True

    ##################################################
    # Trains if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, train_set, validation_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Tests if necessary

    if need_to_train or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################