7fde85d673f3e95e0f3403bb48384b3c7872ac3a
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28
29 from colorama import Fore, Back, Style
30
31 # Pytorch
32
33 import torch
34
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
38 from torch import nn
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
41
42 # SVRT
43
44 import svrtset
45
46 ######################################################################
47
48 parser = argparse.ArgumentParser(
49     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50     formatter_class = argparse.ArgumentDefaultsHelpFormatter
51 )
52
53 parser.add_argument('--nb_train_samples',
54                     type = int, default = 100000)
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000)
58
59 parser.add_argument('--nb_epochs',
60                     type = int, default = 50)
61
62 parser.add_argument('--batch_size',
63                     type = int, default = 100)
64
65 parser.add_argument('--log_file',
66                     type = str, default = 'default.log')
67
68 parser.add_argument('--compress_vignettes',
69                     type = distutils.util.strtobool, default = 'True',
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 parser.add_argument('--deep_model',
73                     type = distutils.util.strtobool, default = 'True',
74                     help = 'Use Afroze\'s Alexnet-like deep model')
75
76 parser.add_argument('--test_loaded_models',
77                     type = distutils.util.strtobool, default = 'False',
78                     help = 'Should we compute the test errors of loaded models')
79
80 parser.add_argument('--problems',
81                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
82                     help = 'What problem to process')
83
84 args = parser.parse_args()
85
86 ######################################################################
87
88 log_file = open(args.log_file, 'w')
89 pred_log_t = None
90
91 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
92
93 # Log and prints the string, with a time stamp. Does not log the
94 # remark
95 def log_string(s, remark = ''):
96     global pred_log_t
97
98     t = time.time()
99
100     if pred_log_t is None:
101         elapsed = 'start'
102     else:
103         elapsed = '+{:.02f}s'.format(t - pred_log_t)
104
105     pred_log_t = t
106
107     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
108     log_file.flush()
109
110     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
111
112 ######################################################################
113
114 # Afroze's ShallowNet
115
116 #                       map size   nb. maps
117 #                     ----------------------
118 #    input                128x128    1
119 # -- conv(21x21 x 6)   -> 108x108    6
120 # -- max(2x2)          -> 54x54      6
121 # -- conv(19x19 x 16)  -> 36x36      16
122 # -- max(2x2)          -> 18x18      16
123 # -- conv(18x18 x 120) -> 1x1        120
124 # -- reshape           -> 120        1
125 # -- full(120x84)      -> 84         1
126 # -- full(84x2)        -> 2          1
127
128 class AfrozeShallowNet(nn.Module):
129     def __init__(self):
130         super(AfrozeShallowNet, self).__init__()
131         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
132         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
133         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
134         self.fc1 = nn.Linear(120, 84)
135         self.fc2 = nn.Linear(84, 2)
136         self.name = 'shallownet'
137
138     def forward(self, x):
139         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
140         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
141         x = fn.relu(self.conv3(x))
142         x = x.view(-1, 120)
143         x = fn.relu(self.fc1(x))
144         x = self.fc2(x)
145         return x
146
147 ######################################################################
148
149 # Afroze's DeepNet
150
151 class AfrozeDeepNet(nn.Module):
152     def __init__(self):
153         super(AfrozeDeepNet, self).__init__()
154         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
155         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
156         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
157         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
158         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
159         self.fc1 = nn.Linear(1536, 256)
160         self.fc2 = nn.Linear(256, 256)
161         self.fc3 = nn.Linear(256, 2)
162         self.name = 'deepnet'
163
164     def forward(self, x):
165         x = self.conv1(x)
166         x = fn.max_pool2d(x, kernel_size=2)
167         x = fn.relu(x)
168
169         x = self.conv2(x)
170         x = fn.max_pool2d(x, kernel_size=2)
171         x = fn.relu(x)
172
173         x = self.conv3(x)
174         x = fn.relu(x)
175
176         x = self.conv4(x)
177         x = fn.relu(x)
178
179         x = self.conv5(x)
180         x = fn.max_pool2d(x, kernel_size=2)
181         x = fn.relu(x)
182
183         x = x.view(-1, 1536)
184
185         x = self.fc1(x)
186         x = fn.relu(x)
187
188         x = self.fc2(x)
189         x = fn.relu(x)
190
191         x = self.fc3(x)
192
193         return x
194
195 ######################################################################
196
197 def train_model(model, train_set):
198     batch_size = args.batch_size
199     criterion = nn.CrossEntropyLoss()
200
201     if torch.cuda.is_available():
202         criterion.cuda()
203
204     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
205
206     start_t = time.time()
207
208     for e in range(0, args.nb_epochs):
209         acc_loss = 0.0
210         for b in range(0, train_set.nb_batches):
211             input, target = train_set.get_batch(b)
212             output = model.forward(Variable(input))
213             loss = criterion(output, Variable(target))
214             acc_loss = acc_loss + loss.data[0]
215             model.zero_grad()
216             loss.backward()
217             optimizer.step()
218         dt = (time.time() - start_t) / (e + 1)
219         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
220                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
221
222     return model
223
224 ######################################################################
225
226 def nb_errors(model, data_set):
227     ne = 0
228     for b in range(0, data_set.nb_batches):
229         input, target = data_set.get_batch(b)
230         output = model.forward(Variable(input))
231         wta_prediction = output.data.max(1)[1].view(-1)
232
233         for i in range(0, data_set.batch_size):
234             if wta_prediction[i] != target[i]:
235                 ne = ne + 1
236
237     return ne
238
239 ######################################################################
240
241 for arg in vars(args):
242     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
243
244 ######################################################################
245
246 def int_to_suffix(n):
247     if n >= 1000000 and n%1000000 == 0:
248         return str(n//1000000) + 'M'
249     elif n >= 1000 and n%1000 == 0:
250         return str(n//1000) + 'K'
251     else:
252         return str(n)
253
254 class vignette_logger():
255     def __init__(self, delay_min = 60):
256         self.start_t = time.time()
257         self.last_t = self.start_t
258         self.delay_min = delay_min
259
260     def __call__(self, n, m):
261         t = time.time()
262         if t > self.last_t + self.delay_min:
263             dt = (t - self.start_t) / m
264             log_string('sample_generation {:d} / {:d}'.format(
265                 m,
266                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
267             )
268             self.last_t = t
269
270 ######################################################################
271
272 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
273     print('The number of samples must be a multiple of the batch size.')
274     raise
275
276 if args.compress_vignettes:
277     log_string('using_compressed_vignettes')
278     VignetteSet = svrtset.CompressedVignetteSet
279 else:
280     log_string('using_uncompressed_vignettes')
281     VignetteSet = svrtset.VignetteSet
282
283 for problem_number in map(int, args.problems.split(',')):
284
285     log_string('############### problem ' + str(problem_number) + ' ###############')
286
287     if args.deep_model:
288         model = AfrozeDeepNet()
289     else:
290         model = AfrozeShallowNet()
291
292     if torch.cuda.is_available(): model.cuda()
293
294     model_filename = model.name + '_pb:' + \
295                      str(problem_number) + '_ns:' + \
296                      int_to_suffix(args.nb_train_samples) + '.param'
297
298     nb_parameters = 0
299     for p in model.parameters(): nb_parameters += p.numel()
300     log_string('nb_parameters {:d}'.format(nb_parameters))
301
302     ##################################################
303     # Tries to load the model
304
305     need_to_train = False
306     try:
307         model.load_state_dict(torch.load(model_filename))
308         log_string('loaded_model ' + model_filename)
309     except:
310         need_to_train = True
311
312     ##################################################
313     # Train if necessary
314
315     if need_to_train:
316
317         log_string('training_model ' + model_filename)
318
319         t = time.time()
320
321         train_set = VignetteSet(problem_number,
322                                 args.nb_train_samples, args.batch_size,
323                                 cuda = torch.cuda.is_available(),
324                                 logger = vignette_logger())
325
326         log_string('data_generation {:0.2f} samples / s'.format(
327             train_set.nb_samples / (time.time() - t))
328         )
329
330         train_model(model, train_set)
331         torch.save(model.state_dict(), model_filename)
332         log_string('saved_model ' + model_filename)
333
334         nb_train_errors = nb_errors(model, train_set)
335
336         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
337             problem_number,
338             100 * nb_train_errors / train_set.nb_samples,
339             nb_train_errors,
340             train_set.nb_samples)
341         )
342
343     ##################################################
344     # Test if necessary
345
346     if need_to_train or args.test_loaded_models:
347
348         t = time.time()
349
350         test_set = VignetteSet(problem_number,
351                                args.nb_test_samples, args.batch_size,
352                                cuda = torch.cuda.is_available())
353
354         log_string('data_generation {:0.2f} samples / s'.format(
355             test_set.nb_samples / (time.time() - t))
356         )
357
358         nb_test_errors = nb_errors(model, test_set)
359
360         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
361             problem_number,
362             100 * nb_test_errors / test_set.nb_samples,
363             nb_test_errors,
364             test_set.nb_samples)
365         )
366
367 ######################################################################