Cosmetics.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_samples',
53                     type = int, default = 100000)
54
55 parser.add_argument('--nb_test_samples',
56                     type = int, default = 10000)
57
58 parser.add_argument('--nb_epochs',
59                     type = int, default = 50)
60
61 parser.add_argument('--batch_size',
62                     type = int, default = 100)
63
64 parser.add_argument('--log_file',
65                     type = str, default = 'default.log')
66
67 parser.add_argument('--compress_vignettes',
68                     action='store_true', default = True,
69                     help = 'Use lossless compression to reduce the memory footprint')
70
71 parser.add_argument('--deep_model',
72                     action='store_true', default = True,
73                     help = 'Use Afroze\'s Alexnet-like deep model')
74
75 parser.add_argument('--test_loaded_models',
76                     action='store_true', default = False,
77                     help = 'Should we compute the test errors of loaded models')
78
79 args = parser.parse_args()
80
81 ######################################################################
82
83 log_file = open(args.log_file, 'w')
84 pred_log_t = None
85
86 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
87
88 # Log and prints the string, with a time stamp. Does not log the
89 # remark
90 def log_string(s, remark = ''):
91     global pred_log_t
92
93     t = time.time()
94
95     if pred_log_t is None:
96         elapsed = 'start'
97     else:
98         elapsed = '+{:.02f}s'.format(t - pred_log_t)
99
100     pred_log_t = t
101
102     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
103     log_file.flush()
104
105     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
106
107 ######################################################################
108
109 # Afroze's ShallowNet
110
111 #                       map size   nb. maps
112 #                     ----------------------
113 #    input                128x128    1
114 # -- conv(21x21 x 6)   -> 108x108    6
115 # -- max(2x2)          -> 54x54      6
116 # -- conv(19x19 x 16)  -> 36x36      16
117 # -- max(2x2)          -> 18x18      16
118 # -- conv(18x18 x 120) -> 1x1        120
119 # -- reshape           -> 120        1
120 # -- full(120x84)      -> 84         1
121 # -- full(84x2)        -> 2          1
122
123 class AfrozeShallowNet(nn.Module):
124     def __init__(self):
125         super(AfrozeShallowNet, self).__init__()
126         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
127         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
128         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
129         self.fc1 = nn.Linear(120, 84)
130         self.fc2 = nn.Linear(84, 2)
131         self.name = 'shallownet'
132
133     def forward(self, x):
134         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
135         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
136         x = fn.relu(self.conv3(x))
137         x = x.view(-1, 120)
138         x = fn.relu(self.fc1(x))
139         x = self.fc2(x)
140         return x
141
142 ######################################################################
143
144 # Afroze's DeepNet
145
146 #                       map size   nb. maps
147 #                     ----------------------
148 #    input                128x128    1
149 # -- conv(21x21 x 32 stride=4) -> 28x28    32
150 # -- max(2x2)                  -> 14x14      6
151 # -- conv(7x7 x 96)            -> 8x8      16
152 # -- max(2x2)                  -> 4x4      16
153 # -- conv(5x5 x 96)            -> 26x36      16
154 # -- conv(3x3 x 128)           -> 36x36      16
155 # -- conv(3x3 x 128)           -> 36x36      16
156
157 # -- conv(5x5 x 120) -> 1x1        120
158 # -- reshape           -> 120        1
159 # -- full(3x84)      -> 84         1
160 # -- full(84x2)        -> 2          1
161
162 class AfrozeDeepNet(nn.Module):
163     def __init__(self):
164         super(AfrozeDeepNet, self).__init__()
165         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
166         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
167         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
168         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
169         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
170         self.fc1 = nn.Linear(1536, 256)
171         self.fc2 = nn.Linear(256, 256)
172         self.fc3 = nn.Linear(256, 2)
173         self.name = 'deepnet'
174
175     def forward(self, x):
176         x = self.conv1(x)
177         x = fn.max_pool2d(x, kernel_size=2)
178         x = fn.relu(x)
179
180         x = self.conv2(x)
181         x = fn.max_pool2d(x, kernel_size=2)
182         x = fn.relu(x)
183
184         x = self.conv3(x)
185         x = fn.relu(x)
186
187         x = self.conv4(x)
188         x = fn.relu(x)
189
190         x = self.conv5(x)
191         x = fn.max_pool2d(x, kernel_size=2)
192         x = fn.relu(x)
193
194         x = x.view(-1, 1536)
195
196         x = self.fc1(x)
197         x = fn.relu(x)
198
199         x = self.fc2(x)
200         x = fn.relu(x)
201
202         x = self.fc3(x)
203
204         return x
205
206 ######################################################################
207
208 def train_model(model, train_set):
209     batch_size = args.batch_size
210     criterion = nn.CrossEntropyLoss()
211
212     if torch.cuda.is_available():
213         criterion.cuda()
214
215     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
216
217     start_t = time.time()
218
219     for e in range(0, args.nb_epochs):
220         acc_loss = 0.0
221         for b in range(0, train_set.nb_batches):
222             input, target = train_set.get_batch(b)
223             output = model.forward(Variable(input))
224             loss = criterion(output, Variable(target))
225             acc_loss = acc_loss + loss.data[0]
226             model.zero_grad()
227             loss.backward()
228             optimizer.step()
229         dt = (time.time() - start_t) / (e + 1)
230         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
231                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
232
233     return model
234
235 ######################################################################
236
237 def nb_errors(model, data_set):
238     ne = 0
239     for b in range(0, data_set.nb_batches):
240         input, target = data_set.get_batch(b)
241         output = model.forward(Variable(input))
242         wta_prediction = output.data.max(1)[1].view(-1)
243
244         for i in range(0, data_set.batch_size):
245             if wta_prediction[i] != target[i]:
246                 ne = ne + 1
247
248     return ne
249
250 ######################################################################
251
252 for arg in vars(args):
253     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
254
255 ######################################################################
256
257 def int_to_suffix(n):
258     if n > 1000000 and n%1000000 == 0:
259         return str(n//1000000) + 'M'
260     elif n > 1000 and n%1000 == 0:
261         return str(n//1000) + 'K'
262     else:
263         return str(n)
264
265 ######################################################################
266
267 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
268     print('The number of samples must be a multiple of the batch size.')
269     raise
270
271 for problem_number in range(1, 24):
272
273     log_string('**** problem ' + str(problem_number) + ' ****')
274
275     if args.deep_model:
276         model = AfrozeDeepNet()
277     else:
278         model = AfrozeShallowNet()
279
280     if torch.cuda.is_available():
281         model.cuda()
282
283     model_filename = model.name + '_' + \
284                      str(problem_number) + '_' + \
285                      int_to_suffix(args.nb_train_samples) + '.param'
286
287     nb_parameters = 0
288     for p in model.parameters(): nb_parameters += p.numel()
289     log_string('nb_parameters {:d}'.format(nb_parameters))
290
291     need_to_train = False
292     try:
293         model.load_state_dict(torch.load(model_filename))
294         log_string('loaded_model ' + model_filename)
295     except:
296         need_to_train = True
297
298     if need_to_train:
299
300         log_string('training_model ' + model_filename)
301
302         t = time.time()
303
304         if args.compress_vignettes:
305             train_set = CompressedVignetteSet(problem_number,
306                                               args.nb_train_samples, args.batch_size,
307                                               cuda = torch.cuda.is_available())
308         else:
309             train_set = VignetteSet(problem_number,
310                                     args.nb_train_samples, args.batch_size,
311                                     cuda = torch.cuda.is_available())
312
313         log_string('data_generation {:0.2f} samples / s'.format(
314             train_set.nb_samples / (time.time() - t))
315         )
316
317         train_model(model, train_set)
318         torch.save(model.state_dict(), model_filename)
319         log_string('saved_model ' + model_filename)
320
321         nb_train_errors = nb_errors(model, train_set)
322
323         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
324             problem_number,
325             100 * nb_train_errors / train_set.nb_samples,
326             nb_train_errors,
327             train_set.nb_samples)
328         )
329
330     if need_to_train or args.test_loaded_models:
331
332         t = time.time()
333
334         if args.compress_vignettes:
335             test_set = CompressedVignetteSet(problem_number,
336                                              args.nb_test_samples, args.batch_size,
337                                              cuda = torch.cuda.is_available())
338         else:
339             test_set = VignetteSet(problem_number,
340                                    args.nb_test_samples, args.batch_size,
341                                    cuda = torch.cuda.is_available())
342
343         log_string('data_generation {:0.2f} samples / s'.format(
344             test_set.nb_samples / (time.time() - t))
345         )
346
347         nb_test_errors = nb_errors(model, test_set)
348
349         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
350             problem_number,
351             100 * nb_test_errors / test_set.nb_samples,
352             nb_test_errors,
353             test_set.nb_samples)
354         )
355
356 ######################################################################