3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import time

import torch

from colorama import Fore, Back, Style

from torch import nn
from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line configuration of the experiment.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# NOTE(review): action='store_true' combined with default=True means the
# next two flags are always True and cannot be switched off from the
# command line; kept as-is to preserve behavior, but consider
# '--no_...' store_false variants.
parser.add_argument('--compress_vignettes',
                    action='store_true', default = True,
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    action='store_true', default = True,
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    action='store_true', default = False,
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
81 ######################################################################
log_file = open(args.log_file, 'w')
# Time stamp of the previous log_string call; None before the first call.
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string s, with a time stamp and the time elapsed
# since the previous call. Does not log the remark, which is only
# shown on the console.
def log_string(s, remark = ''):
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
    # Flush so a crash mid-run does not lose the log tail.
    log_file.flush()

    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
107 ######################################################################
109 # Afroze's ShallowNet
112 # ----------------------
114 # -- conv(21x21 x 6) -> 108x108 6
115 # -- max(2x2) -> 54x54 6
116 # -- conv(19x19 x 16) -> 36x36 16
117 # -- max(2x2) -> 18x18 16
118 # -- conv(18x18 x 120) -> 1x1 120
119 # -- reshape -> 120 1
120 # -- full(120x84) -> 84 1
121 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    # Afroze's ShallowNet (see the layer-size table above); takes a
    # batch of 1x128x128 vignettes and returns two class logits.
    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used to build the parameter-file name in the main loop.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))  # 108x108 -> 54x54, 6 ch
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))  # 36x36 -> 18x18, 16 ch
        x = fn.relu(self.conv3(x))                                # 1x1, 120 ch
        # Flatten the 1x1x120 maps before the fully connected layers.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
142 ######################################################################
# ----------------------
#
# -- conv(7x7 x 32, stride=4, pad=3) -> 32x32   32
# -- max(2x2)                        -> 16x16   32
# -- conv(5x5 x 96, pad=2)           -> 16x16   96
# -- max(2x2)                        ->  8x8    96
# -- conv(3x3 x 128, pad=1)          ->  8x8   128
# -- conv(3x3 x 128, pad=1)          ->  8x8   128
# -- conv(3x3 x 96, pad=1)           ->  8x8    96
# -- max(2x2)                        ->  4x4    96
# -- reshape                         -> 1536     1
# -- full(1536x256)                  ->  256     1
# -- full(256x256)                   ->  256     1
# -- full(256x2)                     ->    2     1
class AfrozeDeepNet(nn.Module):
    # Afroze's Alexnet-like DeepNet; takes a batch of 1x128x128
    # vignettes and returns two class logits (fc1's 1536 inputs
    # correspond to 96 channels x 4x4 after the three poolings).
    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used to build the parameter-file name in the main loop.
        self.name = 'deepnet'

    def forward(self, x):
        x = self.conv1(x)                    # 128x128 -> 32x32
        x = fn.max_pool2d(x, kernel_size=2)  # -> 16x16
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 8x8
        x = fn.relu(x)

        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)  # -> 4x4, 96 ch = 1536 features
        x = fn.relu(x)

        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
206 ######################################################################
# Trains model on train_set for args.nb_epochs epochs with SGD and the
# cross-entropy loss, logging the per-epoch loss and an ETA, and
# returns the trained model. train_set must provide nb_batches and
# get_batch(b) -> (input, target).
def train_model(model, train_set):
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): loss.data[0] is the pre-0.4 PyTorch idiom;
            # on modern PyTorch this would be loss.item().
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        # Mean epoch duration so far, used to estimate the finish time.
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

    return model
235 ######################################################################
# Returns the number of samples of data_set misclassified by model
# under the winner-take-all (argmax over the two logits) rule.
def nb_errors(model, data_set):
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Index of the max logit per sample, flattened to a 1d vector.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
250 ######################################################################
# Dump the full configuration into the log so each run is self-describing.
for argument_name in vars(args):
    argument_value = getattr(args, argument_name)
    log_string('argument ' + str(argument_name) + ' ' + str(argument_value))
255 ######################################################################
# Returns a compact string for n: exact multiples of 10^6 above 10^6 as
# 'xM', exact multiples of 10^3 above 10^3 as 'xK', and anything else
# as the plain decimal representation.
def int_to_suffix(n):
    if n > 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n > 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        # Fallback for values with no compact form (the visible code
        # was missing this branch and implicitly returned None).
        return str(n)
265 ######################################################################
# The vignette sets iterate over whole batches, so the sample counts
# must divide evenly; abort instead of silently truncating.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    raise ValueError('The number of samples must be a multiple of the batch size.')
# Main loop: for each of the 23 SVRT problems, build (or reload) a
# model, train it if no saved parameters exist, and report errors.
for problem_number in range(1, 24):

    log_string('**** problem ' + str(problem_number) + ' ****')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    model_filename = model.name + '_' + \
                     str(problem_number) + '_' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    # Reload previously trained parameters when available; otherwise
    # fall back to training from scratch.
    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except IOError:
        need_to_train = True

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        if args.compress_vignettes:
            train_set = CompressedVignetteSet(problem_number,
                                              args.nb_train_samples, args.batch_size,
                                              cuda = torch.cuda.is_available())
        else:
            train_set = VignetteSet(problem_number,
                                    args.nb_train_samples, args.batch_size,
                                    cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    # Freshly trained models are always tested; loaded ones only when
    # --test_loaded_models is given.
    if need_to_train or args.test_loaded_models:

        t = time.time()

        if args.compress_vignettes:
            test_set = CompressedVignetteSet(problem_number,
                                             args.nb_test_samples, args.batch_size,
                                             cuda = torch.cuda.is_available())
        else:
            test_set = VignetteSet(problem_number,
                                   args.nb_test_samples, args.batch_size,
                                   cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))
        )

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )