Minor changes in the logging.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35
36 from torch import optim
37 from torch import FloatTensor as Tensor
38 from torch.autograd import Variable
39 from torch import nn
40 from torch.nn import functional as fn
41 from torchvision import datasets, transforms, utils
42
43 # SVRT
44
45 import svrtset
46
47 ######################################################################
48
49 parser = argparse.ArgumentParser(
50     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
51     formatter_class = argparse.ArgumentDefaultsHelpFormatter
52 )
53
54 parser.add_argument('--nb_train_samples',
55                     type = int, default = 100000)
56
57 parser.add_argument('--nb_test_samples',
58                     type = int, default = 10000)
59
60 parser.add_argument('--nb_validation_samples',
61                     type = int, default = 10000)
62
63 parser.add_argument('--validation_error_threshold',
64                     type = float, default = 0.0,
65                     help = 'Early training termination criterion')
66
67 parser.add_argument('--nb_epochs',
68                     type = int, default = 50)
69
70 parser.add_argument('--batch_size',
71                     type = int, default = 100)
72
73 parser.add_argument('--log_file',
74                     type = str, default = 'default.log')
75
76 parser.add_argument('--compress_vignettes',
77                     type = distutils.util.strtobool, default = 'True',
78                     help = 'Use lossless compression to reduce the memory footprint')
79
80 parser.add_argument('--deep_model',
81                     type = distutils.util.strtobool, default = 'True',
82                     help = 'Use Afroze\'s Alexnet-like deep model')
83
84 parser.add_argument('--test_loaded_models',
85                     type = distutils.util.strtobool, default = 'False',
86                     help = 'Should we compute the test errors of loaded models')
87
88 parser.add_argument('--problems',
89                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
90                     help = 'What problems to process')
91
92 args = parser.parse_args()
93
94 ######################################################################
95
96 log_file = open(args.log_file, 'a')
97 pred_log_t = None
98 last_tag_t = time.time()
99
100 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
101
102 # Log and prints the string, with a time stamp. Does not log the
103 # remark
104
105 def log_string(s, remark = ''):
106     global pred_log_t, last_tag_t
107
108     t = time.time()
109
110     if pred_log_t is None:
111         elapsed = 'start'
112     else:
113         elapsed = '+{:.02f}s'.format(t - pred_log_t)
114
115     pred_log_t = t
116
117     if t > last_tag_t + 3600:
118         last_tag_t = t
119         print(Fore.RED + time.ctime() + Style.RESET_ALL)
120
121     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
122     log_file.flush()
123
124     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
125
126 ######################################################################
127
128 # Afroze's ShallowNet
129
130 #                       map size   nb. maps
131 #                     ----------------------
132 #    input                128x128    1
133 # -- conv(21x21 x 6)   -> 108x108    6
134 # -- max(2x2)          -> 54x54      6
135 # -- conv(19x19 x 16)  -> 36x36      16
136 # -- max(2x2)          -> 18x18      16
137 # -- conv(18x18 x 120) -> 1x1        120
138 # -- reshape           -> 120        1
139 # -- full(120x84)      -> 84         1
140 # -- full(84x2)        -> 2          1
141
142 class AfrozeShallowNet(nn.Module):
143     def __init__(self):
144         super(AfrozeShallowNet, self).__init__()
145         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
146         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
147         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
148         self.fc1 = nn.Linear(120, 84)
149         self.fc2 = nn.Linear(84, 2)
150         self.name = 'shallownet'
151
152     def forward(self, x):
153         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
154         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
155         x = fn.relu(self.conv3(x))
156         x = x.view(-1, 120)
157         x = fn.relu(self.fc1(x))
158         x = self.fc2(x)
159         return x
160
161 ######################################################################
162
163 # Afroze's DeepNet
164
165 class AfrozeDeepNet(nn.Module):
166     def __init__(self):
167         super(AfrozeDeepNet, self).__init__()
168         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
169         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
170         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
171         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
172         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
173         self.fc1 = nn.Linear(1536, 256)
174         self.fc2 = nn.Linear(256, 256)
175         self.fc3 = nn.Linear(256, 2)
176         self.name = 'deepnet'
177
178     def forward(self, x):
179         x = self.conv1(x)
180         x = fn.max_pool2d(x, kernel_size=2)
181         x = fn.relu(x)
182
183         x = self.conv2(x)
184         x = fn.max_pool2d(x, kernel_size=2)
185         x = fn.relu(x)
186
187         x = self.conv3(x)
188         x = fn.relu(x)
189
190         x = self.conv4(x)
191         x = fn.relu(x)
192
193         x = self.conv5(x)
194         x = fn.max_pool2d(x, kernel_size=2)
195         x = fn.relu(x)
196
197         x = x.view(-1, 1536)
198
199         x = self.fc1(x)
200         x = fn.relu(x)
201
202         x = self.fc2(x)
203         x = fn.relu(x)
204
205         x = self.fc3(x)
206
207         return x
208
209 ######################################################################
210
211 def nb_errors(model, data_set):
212     ne = 0
213     for b in range(0, data_set.nb_batches):
214         input, target = data_set.get_batch(b)
215         output = model.forward(Variable(input))
216         wta_prediction = output.data.max(1)[1].view(-1)
217
218         for i in range(0, data_set.batch_size):
219             if wta_prediction[i] != target[i]:
220                 ne = ne + 1
221
222     return ne
223
224 ######################################################################
225
226 def train_model(model, train_set, validation_set):
227     batch_size = args.batch_size
228     criterion = nn.CrossEntropyLoss()
229
230     if torch.cuda.is_available():
231         criterion.cuda()
232
233     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
234
235     start_t = time.time()
236
237     for e in range(0, args.nb_epochs):
238         acc_loss = 0.0
239         for b in range(0, train_set.nb_batches):
240             input, target = train_set.get_batch(b)
241             output = model.forward(Variable(input))
242             loss = criterion(output, Variable(target))
243             acc_loss = acc_loss + loss.data[0]
244             model.zero_grad()
245             loss.backward()
246             optimizer.step()
247         dt = (time.time() - start_t) / (e + 1)
248
249         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
250                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
251
252         if validation_set is not None:
253             nb_validation_errors = nb_errors(model, validation_set)
254
255             log_string('validation_error {:.02f}% {:d} {:d}'.format(
256                 100 * nb_validation_errors / validation_set.nb_samples,
257                 nb_validation_errors,
258                 validation_set.nb_samples)
259             )
260
261             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
262                 log_string('below validation_error_threshold')
263                 break
264
265     return model
266
267 ######################################################################
268
269 for arg in vars(args):
270     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
271
272 ######################################################################
273
274 def int_to_suffix(n):
275     if n >= 1000000 and n%1000000 == 0:
276         return str(n//1000000) + 'M'
277     elif n >= 1000 and n%1000 == 0:
278         return str(n//1000) + 'K'
279     else:
280         return str(n)
281
282 class vignette_logger():
283     def __init__(self, delay_min = 60):
284         self.start_t = time.time()
285         self.last_t = self.start_t
286         self.delay_min = delay_min
287
288     def __call__(self, n, m):
289         t = time.time()
290         if t > self.last_t + self.delay_min:
291             dt = (t - self.start_t) / m
292             log_string('sample_generation {:d} / {:d}'.format(
293                 m,
294                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
295             )
296             self.last_t = t
297
298 ######################################################################
299
300 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
301     print('The number of samples must be a multiple of the batch size.')
302     raise
303
304 log_string('############### start ###############')
305
306 if args.compress_vignettes:
307     log_string('using_compressed_vignettes')
308     VignetteSet = svrtset.CompressedVignetteSet
309 else:
310     log_string('using_uncompressed_vignettes')
311     VignetteSet = svrtset.VignetteSet
312
313 for problem_number in map(int, args.problems.split(',')):
314
315     log_string('############### problem ' + str(problem_number) + ' ###############')
316
317     if args.deep_model:
318         model = AfrozeDeepNet()
319     else:
320         model = AfrozeShallowNet()
321
322     if torch.cuda.is_available(): model.cuda()
323
324     model_filename = model.name + '_pb:' + \
325                      str(problem_number) + '_ns:' + \
326                      int_to_suffix(args.nb_train_samples) + '.param'
327
328     nb_parameters = 0
329     for p in model.parameters(): nb_parameters += p.numel()
330     log_string('nb_parameters {:d}'.format(nb_parameters))
331
332     ##################################################
333     # Tries to load the model
334
335     need_to_train = False
336     try:
337         model.load_state_dict(torch.load(model_filename))
338         log_string('loaded_model ' + model_filename)
339     except:
340         need_to_train = True
341
342     ##################################################
343     # Train if necessary
344
345     if need_to_train:
346
347         log_string('training_model ' + model_filename)
348
349         t = time.time()
350
351         train_set = VignetteSet(problem_number,
352                                 args.nb_train_samples, args.batch_size,
353                                 cuda = torch.cuda.is_available(),
354                                 logger = vignette_logger())
355
356         log_string('data_generation {:0.2f} samples / s'.format(
357             train_set.nb_samples / (time.time() - t))
358         )
359
360         if args.validation_error_threshold > 0.0:
361             validation_set = VignetteSet(problem_number,
362                                          args.nb_validation_samples, args.batch_size,
363                                          cuda = torch.cuda.is_available(),
364                                          logger = vignette_logger())
365         else:
366             validation_set = None
367
368         train_model(model, train_set, validation_set)
369         torch.save(model.state_dict(), model_filename)
370         log_string('saved_model ' + model_filename)
371
372         nb_train_errors = nb_errors(model, train_set)
373
374         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
375             problem_number,
376             100 * nb_train_errors / train_set.nb_samples,
377             nb_train_errors,
378             train_set.nb_samples)
379         )
380
381     ##################################################
382     # Test if necessary
383
384     if need_to_train or args.test_loaded_models:
385
386         t = time.time()
387
388         test_set = VignetteSet(problem_number,
389                                args.nb_test_samples, args.batch_size,
390                                cuda = torch.cuda.is_available())
391
392         nb_test_errors = nb_errors(model, test_set)
393
394         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
395             problem_number,
396             100 * nb_test_errors / test_set.nb_samples,
397             nb_test_errors,
398             test_set.nb_samples)
399         )
400
401 ######################################################################