0216b4e5cbba5740e053d43d5d9fad5ff0ab2f27
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28
29 from colorama import Fore, Back, Style
30
31 # Pytorch
32
33 import torch
34
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
38 from torch import nn
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
41
42 # SVRT
43
44 import vignette_set
45
46 ######################################################################
47
48 parser = argparse.ArgumentParser(
49     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50     formatter_class = argparse.ArgumentDefaultsHelpFormatter
51 )
52
53 parser.add_argument('--nb_train_samples',
54                     type = int, default = 100000)
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000)
58
59 parser.add_argument('--nb_epochs',
60                     type = int, default = 50)
61
62 parser.add_argument('--batch_size',
63                     type = int, default = 100)
64
65 parser.add_argument('--log_file',
66                     type = str, default = 'default.log')
67
68 parser.add_argument('--compress_vignettes',
69                     type = distutils.util.strtobool, default = 'True',
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 parser.add_argument('--deep_model',
73                     type = distutils.util.strtobool, default = 'True',
74                     help = 'Use Afroze\'s Alexnet-like deep model')
75
76 parser.add_argument('--test_loaded_models',
77                     type = distutils.util.strtobool, default = 'False',
78                     help = 'Should we compute the test errors of loaded models')
79
80 args = parser.parse_args()
81
82 ######################################################################
83
84 log_file = open(args.log_file, 'w')
85 pred_log_t = None
86
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
88
89 # Log and prints the string, with a time stamp. Does not log the
90 # remark
91 def log_string(s, remark = ''):
92     global pred_log_t
93
94     t = time.time()
95
96     if pred_log_t is None:
97         elapsed = 'start'
98     else:
99         elapsed = '+{:.02f}s'.format(t - pred_log_t)
100
101     pred_log_t = t
102
103     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
104     log_file.flush()
105
106     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
107
108 ######################################################################
109
110 # Afroze's ShallowNet
111
112 #                       map size   nb. maps
113 #                     ----------------------
114 #    input                128x128    1
115 # -- conv(21x21 x 6)   -> 108x108    6
116 # -- max(2x2)          -> 54x54      6
117 # -- conv(19x19 x 16)  -> 36x36      16
118 # -- max(2x2)          -> 18x18      16
119 # -- conv(18x18 x 120) -> 1x1        120
120 # -- reshape           -> 120        1
121 # -- full(120x84)      -> 84         1
122 # -- full(84x2)        -> 2          1
123
124 class AfrozeShallowNet(nn.Module):
125     def __init__(self):
126         super(AfrozeShallowNet, self).__init__()
127         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
128         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
129         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
130         self.fc1 = nn.Linear(120, 84)
131         self.fc2 = nn.Linear(84, 2)
132         self.name = 'shallownet'
133
134     def forward(self, x):
135         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
136         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
137         x = fn.relu(self.conv3(x))
138         x = x.view(-1, 120)
139         x = fn.relu(self.fc1(x))
140         x = self.fc2(x)
141         return x
142
143 ######################################################################
144
145 # Afroze's DeepNet
146
147 #                       map size   nb. maps
148 #                     ----------------------
149 #    input                128x128    1
150 # -- conv(21x21 x 32 stride=4) -> 28x28    32
151 # -- max(2x2)                  -> 14x14      6
152 # -- conv(7x7 x 96)            -> 8x8      16
153 # -- max(2x2)                  -> 4x4      16
154 # -- conv(5x5 x 96)            -> 26x36      16
155 # -- conv(3x3 x 128)           -> 36x36      16
156 # -- conv(3x3 x 128)           -> 36x36      16
157
158 # -- conv(5x5 x 120) -> 1x1        120
159 # -- reshape           -> 120        1
160 # -- full(3x84)      -> 84         1
161 # -- full(84x2)        -> 2          1
162
163 class AfrozeDeepNet(nn.Module):
164     def __init__(self):
165         super(AfrozeDeepNet, self).__init__()
166         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
167         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
168         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
169         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
170         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
171         self.fc1 = nn.Linear(1536, 256)
172         self.fc2 = nn.Linear(256, 256)
173         self.fc3 = nn.Linear(256, 2)
174         self.name = 'deepnet'
175
176     def forward(self, x):
177         x = self.conv1(x)
178         x = fn.max_pool2d(x, kernel_size=2)
179         x = fn.relu(x)
180
181         x = self.conv2(x)
182         x = fn.max_pool2d(x, kernel_size=2)
183         x = fn.relu(x)
184
185         x = self.conv3(x)
186         x = fn.relu(x)
187
188         x = self.conv4(x)
189         x = fn.relu(x)
190
191         x = self.conv5(x)
192         x = fn.max_pool2d(x, kernel_size=2)
193         x = fn.relu(x)
194
195         x = x.view(-1, 1536)
196
197         x = self.fc1(x)
198         x = fn.relu(x)
199
200         x = self.fc2(x)
201         x = fn.relu(x)
202
203         x = self.fc3(x)
204
205         return x
206
207 ######################################################################
208
209 def train_model(model, train_set):
210     batch_size = args.batch_size
211     criterion = nn.CrossEntropyLoss()
212
213     if torch.cuda.is_available():
214         criterion.cuda()
215
216     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
217
218     start_t = time.time()
219
220     for e in range(0, args.nb_epochs):
221         acc_loss = 0.0
222         for b in range(0, train_set.nb_batches):
223             input, target = train_set.get_batch(b)
224             output = model.forward(Variable(input))
225             loss = criterion(output, Variable(target))
226             acc_loss = acc_loss + loss.data[0]
227             model.zero_grad()
228             loss.backward()
229             optimizer.step()
230         dt = (time.time() - start_t) / (e + 1)
231         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
232                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
233
234     return model
235
236 ######################################################################
237
238 def nb_errors(model, data_set):
239     ne = 0
240     for b in range(0, data_set.nb_batches):
241         input, target = data_set.get_batch(b)
242         output = model.forward(Variable(input))
243         wta_prediction = output.data.max(1)[1].view(-1)
244
245         for i in range(0, data_set.batch_size):
246             if wta_prediction[i] != target[i]:
247                 ne = ne + 1
248
249     return ne
250
251 ######################################################################
252
253 for arg in vars(args):
254     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
255
256 ######################################################################
257
258 def int_to_suffix(n):
259     if n >= 1000000 and n%1000000 == 0:
260         return str(n//1000000) + 'M'
261     elif n >= 1000 and n%1000 == 0:
262         return str(n//1000) + 'K'
263     else:
264         return str(n)
265
266 ######################################################################
267
268 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
269     print('The number of samples must be a multiple of the batch size.')
270     raise
271
272 if args.compress_vignettes:
273     log_string('using_compressed_vignettes')
274     VignetteSet = vignette_set.CompressedVignetteSet
275 else:
276     log_string('using_uncompressed_vignettes')
277     VignetteSet = vignette_set.VignetteSet
278
279 for problem_number in range(1, 24):
280
281     log_string('############### problem ' + str(problem_number) + ' ###############')
282
283     if args.deep_model:
284         model = AfrozeDeepNet()
285     else:
286         model = AfrozeShallowNet()
287
288     if torch.cuda.is_available(): model.cuda()
289
290     model_filename = model.name + '_pb:' + \
291                      str(problem_number) + '_ns:' + \
292                      int_to_suffix(args.nb_train_samples) + '.param'
293
294     nb_parameters = 0
295     for p in model.parameters(): nb_parameters += p.numel()
296     log_string('nb_parameters {:d}'.format(nb_parameters))
297
298     ##################################################
299     # Tries to load the model
300
301     need_to_train = False
302     try:
303         model.load_state_dict(torch.load(model_filename))
304         log_string('loaded_model ' + model_filename)
305     except:
306         need_to_train = True
307
308     ##################################################
309     # Train if necessary
310
311     if need_to_train:
312
313         log_string('training_model ' + model_filename)
314
315         t = time.time()
316
317         train_set = VignetteSet(problem_number,
318                                 args.nb_train_samples, args.batch_size,
319                                 cuda = torch.cuda.is_available())
320
321         log_string('data_generation {:0.2f} samples / s'.format(
322             train_set.nb_samples / (time.time() - t))
323         )
324
325         train_model(model, train_set)
326         torch.save(model.state_dict(), model_filename)
327         log_string('saved_model ' + model_filename)
328
329         nb_train_errors = nb_errors(model, train_set)
330
331         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
332             problem_number,
333             100 * nb_train_errors / train_set.nb_samples,
334             nb_train_errors,
335             train_set.nb_samples)
336         )
337
338     ##################################################
339     # Test if necessary
340
341     if need_to_train or args.test_loaded_models:
342
343         t = time.time()
344
345         test_set = VignetteSet(problem_number,
346                                args.nb_test_samples, args.batch_size,
347                                cuda = torch.cuda.is_available())
348
349         log_string('data_generation {:0.2f} samples / s'.format(
350             test_set.nb_samples / (time.time() - t))
351         )
352
353         nb_test_errors = nb_errors(model, test_set)
354
355         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
356             problem_number,
357             100 * nb_test_errors / test_set.nb_samples,
358             nb_test_errors,
359             test_set.nb_samples)
360         )
361
362 ######################################################################