Update.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35 import torchvision
36
37 from torch import optim
38 from torch import FloatTensor as Tensor
39 from torch.autograd import Variable
40 from torch import nn
41 from torch.nn import functional as fn
42 from torchvision import datasets, transforms, utils
43
44 # SVRT
45
46 import svrtset
47
48 ######################################################################
49
50 parser = argparse.ArgumentParser(
51     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
52     formatter_class = argparse.ArgumentDefaultsHelpFormatter
53 )
54
55 parser.add_argument('--nb_train_samples',
56                     type = int, default = 100000)
57
58 parser.add_argument('--nb_test_samples',
59                     type = int, default = 10000)
60
61 parser.add_argument('--nb_validation_samples',
62                     type = int, default = 10000)
63
64 parser.add_argument('--validation_error_threshold',
65                     type = float, default = 0.0,
66                     help = 'Early training termination criterion')
67
68 parser.add_argument('--nb_epochs',
69                     type = int, default = 50)
70
71 parser.add_argument('--batch_size',
72                     type = int, default = 100)
73
74 parser.add_argument('--log_file',
75                     type = str, default = 'default.log')
76
77 parser.add_argument('--nb_exemplar_vignettes',
78                     type = int, default = -1)
79
80 parser.add_argument('--compress_vignettes',
81                     type = distutils.util.strtobool, default = 'True',
82                     help = 'Use lossless compression to reduce the memory footprint')
83
84 parser.add_argument('--deep_model',
85                     type = distutils.util.strtobool, default = 'True',
86                     help = 'Use Afroze\'s Alexnet-like deep model')
87
88 parser.add_argument('--test_loaded_models',
89                     type = distutils.util.strtobool, default = 'False',
90                     help = 'Should we compute the test errors of loaded models')
91
92 parser.add_argument('--problems',
93                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
94                     help = 'What problems to process')
95
96 args = parser.parse_args()
97
98 ######################################################################
99
100 log_file = open(args.log_file, 'a')
101 pred_log_t = None
102 last_tag_t = time.time()
103
104 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
105
106 # Log and prints the string, with a time stamp. Does not log the
107 # remark
108
109 def log_string(s, remark = ''):
110     global pred_log_t, last_tag_t
111
112     t = time.time()
113
114     if pred_log_t is None:
115         elapsed = 'start'
116     else:
117         elapsed = '+{:.02f}s'.format(t - pred_log_t)
118
119     pred_log_t = t
120
121     if t > last_tag_t + 3600:
122         last_tag_t = t
123         print(Fore.RED + time.ctime() + Style.RESET_ALL)
124
125     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
126     log_file.flush()
127
128     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
129
130 ######################################################################
131
132 # Afroze's ShallowNet
133
134 #                       map size   nb. maps
135 #                     ----------------------
136 #    input                128x128    1
137 # -- conv(21x21 x 6)   -> 108x108    6
138 # -- max(2x2)          -> 54x54      6
139 # -- conv(19x19 x 16)  -> 36x36      16
140 # -- max(2x2)          -> 18x18      16
141 # -- conv(18x18 x 120) -> 1x1        120
142 # -- reshape           -> 120        1
143 # -- full(120x84)      -> 84         1
144 # -- full(84x2)        -> 2          1
145
146 class AfrozeShallowNet(nn.Module):
147     def __init__(self):
148         super(AfrozeShallowNet, self).__init__()
149         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
150         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
151         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
152         self.fc1 = nn.Linear(120, 84)
153         self.fc2 = nn.Linear(84, 2)
154         self.name = 'shallownet'
155
156     def forward(self, x):
157         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
158         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
159         x = fn.relu(self.conv3(x))
160         x = x.view(-1, 120)
161         x = fn.relu(self.fc1(x))
162         x = self.fc2(x)
163         return x
164
165 ######################################################################
166
167 # Afroze's DeepNet
168
169 class AfrozeDeepNet(nn.Module):
170     def __init__(self):
171         super(AfrozeDeepNet, self).__init__()
172         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
173         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
174         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
175         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
176         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
177         self.fc1 = nn.Linear(1536, 256)
178         self.fc2 = nn.Linear(256, 256)
179         self.fc3 = nn.Linear(256, 2)
180         self.name = 'deepnet'
181
182     def forward(self, x):
183         x = self.conv1(x)
184         x = fn.max_pool2d(x, kernel_size=2)
185         x = fn.relu(x)
186
187         x = self.conv2(x)
188         x = fn.max_pool2d(x, kernel_size=2)
189         x = fn.relu(x)
190
191         x = self.conv3(x)
192         x = fn.relu(x)
193
194         x = self.conv4(x)
195         x = fn.relu(x)
196
197         x = self.conv5(x)
198         x = fn.max_pool2d(x, kernel_size=2)
199         x = fn.relu(x)
200
201         x = x.view(-1, 1536)
202
203         x = self.fc1(x)
204         x = fn.relu(x)
205
206         x = self.fc2(x)
207         x = fn.relu(x)
208
209         x = self.fc3(x)
210
211         return x
212
213 ######################################################################
214
215 def nb_errors(model, data_set):
216     ne = 0
217     for b in range(0, data_set.nb_batches):
218         input, target = data_set.get_batch(b)
219         output = model.forward(Variable(input))
220         wta_prediction = output.data.max(1)[1].view(-1)
221
222         for i in range(0, data_set.batch_size):
223             if wta_prediction[i] != target[i]:
224                 ne = ne + 1
225
226     return ne
227
228 ######################################################################
229
230 def train_model(model, train_set, validation_set):
231     batch_size = args.batch_size
232     criterion = nn.CrossEntropyLoss()
233
234     if torch.cuda.is_available():
235         criterion.cuda()
236
237     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
238
239     start_t = time.time()
240
241     for e in range(0, args.nb_epochs):
242         acc_loss = 0.0
243         for b in range(0, train_set.nb_batches):
244             input, target = train_set.get_batch(b)
245             output = model.forward(Variable(input))
246             loss = criterion(output, Variable(target))
247             acc_loss = acc_loss + loss.data[0]
248             model.zero_grad()
249             loss.backward()
250             optimizer.step()
251         dt = (time.time() - start_t) / (e + 1)
252
253         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
254                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
255
256         if validation_set is not None:
257             nb_validation_errors = nb_errors(model, validation_set)
258
259             log_string('validation_error {:.02f}% {:d} {:d}'.format(
260                 100 * nb_validation_errors / validation_set.nb_samples,
261                 nb_validation_errors,
262                 validation_set.nb_samples)
263             )
264
265             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
266                 log_string('below validation_error_threshold')
267                 break
268
269     return model
270
271 ######################################################################
272
273 for arg in vars(args):
274     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
275
276 ######################################################################
277
278 def int_to_suffix(n):
279     if n >= 1000000 and n%1000000 == 0:
280         return str(n//1000000) + 'M'
281     elif n >= 1000 and n%1000 == 0:
282         return str(n//1000) + 'K'
283     else:
284         return str(n)
285
286 class vignette_logger():
287     def __init__(self, delay_min = 60):
288         self.start_t = time.time()
289         self.last_t = self.start_t
290         self.delay_min = delay_min
291
292     def __call__(self, n, m):
293         t = time.time()
294         if t > self.last_t + self.delay_min:
295             dt = (t - self.start_t) / m
296             log_string('sample_generation {:d} / {:d}'.format(
297                 m,
298                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
299             )
300             self.last_t = t
301
302 def save_examplar_vignettes(data_set, nb, name):
303     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
304
305     for k in range(0, nb):
306         b = n[k] // data_set.batch_size
307         m = n[k] % data_set.batch_size
308         i, t = data_set.get_batch(b)
309         i = i[m].float()
310         i.sub_(i.min())
311         i.div_(i.max())
312         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
313         patchwork[k].copy_(i)
314
315     torchvision.utils.save_image(patchwork, name)
316
317 ######################################################################
318
319 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
320     print('The number of samples must be a multiple of the batch size.')
321     raise
322
323 log_string('############### start ###############')
324
325 if args.compress_vignettes:
326     log_string('using_compressed_vignettes')
327     VignetteSet = svrtset.CompressedVignetteSet
328 else:
329     log_string('using_uncompressed_vignettes')
330     VignetteSet = svrtset.VignetteSet
331
332 for problem_number in map(int, args.problems.split(',')):
333
334     log_string('############### problem ' + str(problem_number) + ' ###############')
335
336     if args.deep_model:
337         model = AfrozeDeepNet()
338     else:
339         model = AfrozeShallowNet()
340
341     if torch.cuda.is_available(): model.cuda()
342
343     model_filename = model.name + '_pb:' + \
344                      str(problem_number) + '_ns:' + \
345                      int_to_suffix(args.nb_train_samples) + '.param'
346
347     nb_parameters = 0
348     for p in model.parameters(): nb_parameters += p.numel()
349     log_string('nb_parameters {:d}'.format(nb_parameters))
350
351     ##################################################
352     # Tries to load the model
353
354     need_to_train = False
355     try:
356         model.load_state_dict(torch.load(model_filename))
357         log_string('loaded_model ' + model_filename)
358     except:
359         need_to_train = True
360
361     ##################################################
362     # Train if necessary
363
364     if need_to_train:
365
366         log_string('training_model ' + model_filename)
367
368         t = time.time()
369
370         train_set = VignetteSet(problem_number,
371                                 args.nb_train_samples, args.batch_size,
372                                 cuda = torch.cuda.is_available(),
373                                 logger = vignette_logger())
374
375         log_string('data_generation {:0.2f} samples / s'.format(
376             train_set.nb_samples / (time.time() - t))
377         )
378
379         if args.nb_exemplar_vignettes > 0:
380             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
381                                     'examplar_{:d}.png'.format(problem_number))
382
383         if args.validation_error_threshold > 0.0:
384             validation_set = VignetteSet(problem_number,
385                                          args.nb_validation_samples, args.batch_size,
386                                          cuda = torch.cuda.is_available(),
387                                          logger = vignette_logger())
388         else:
389             validation_set = None
390
391         train_model(model, train_set, validation_set)
392         torch.save(model.state_dict(), model_filename)
393         log_string('saved_model ' + model_filename)
394
395         nb_train_errors = nb_errors(model, train_set)
396
397         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
398             problem_number,
399             100 * nb_train_errors / train_set.nb_samples,
400             nb_train_errors,
401             train_set.nb_samples)
402         )
403
404     ##################################################
405     # Test if necessary
406
407     if need_to_train or args.test_loaded_models:
408
409         t = time.time()
410
411         test_set = VignetteSet(problem_number,
412                                args.nb_test_samples, args.batch_size,
413                                cuda = torch.cuda.is_available())
414
415         nb_test_errors = nb_errors(model, test_set)
416
417         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
418             problem_number,
419             100 * nb_test_errors / test_set.nb_samples,
420             nb_test_errors,
421             test_set.nb_samples)
422         )
423
424 ######################################################################