3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
45 ######################################################################
def _str_to_bool(value):
    """Convert a command-line token such as 'yes', 'false' or '0' to a bool.

    Raises argparse.ArgumentTypeError for anything that is not a
    recognised spelling, so argparse reports a clean usage error.
    """
    normalized = str(value).strip().lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError('invalid boolean value: ' + repr(value))


# Command-line interface of the script. The parsed values end up in the
# module-level `args` that the rest of the file reads.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# The two options below default to True. With action='store_true' they
# could never be turned off, so they accept an optional boolean value
# instead: a bare `--compress_vignettes` still means True (as before),
# while `--compress_vignettes false` now disables it.
parser.add_argument('--compress_vignettes',
                    type = _str_to_bool, nargs = '?',
                    const = True, default = True, metavar = 'BOOL',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = _str_to_bool, nargs = '?',
                    const = True, default = True, metavar = 'BOOL',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    action='store_true', default = False,
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
81 ######################################################################
# The log file stays open for the whole run; log_string() flushes after
# every entry so the log is usable even if the run is interrupted.
log_file = open(args.log_file, 'w')

# Time of the previous log_string() call, None until the first call.
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string, with a time stamp. Does not log the
# remark
def log_string(s, remark = ''):
    """Write `s` to the log file and echo it, colored, on the console.

    Each entry is prefixed with the wall-clock time and with the time
    elapsed since the previous entry ('start' for the first one).

    s      -- message, written to both the log file and the console.
    remark -- extra text appended on the console only.
    """
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Format the time stamp once so that the file and the console show
    # exactly the same instant.
    stamp = '[' + time.ctime(t) + '] '

    log_file.write(stamp + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + stamp + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
107 ######################################################################
109 # Afroze's ShallowNet
112 # ----------------------
114 # -- conv(21x21 x 6) -> 108x108 6
115 # -- max(2x2) -> 54x54 6
116 # -- conv(19x19 x 16) -> 36x36 16
117 # -- max(2x2) -> 18x18 16
118 # -- conv(18x18 x 120) -> 1x1 120
119 # -- reshape -> 120 1
120 # -- full(120x84) -> 84 1
121 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Three-convolution network followed by two fully connected layers.

    The kernel sizes are chosen so that a single-channel 128x128 input
    is reduced to a 1x1 map with 120 channels by the third convolution
    (128 -> 108 -> 54 -> 36 -> 18 -> 1), which is then flattened and
    mapped to 2 output scores.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used as the prefix of the file the parameters are saved to.
        self.name = 'shallownet'

    def forward(self, x):
        """Return the N x 2 class scores for a batch `x` of N images."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten the N x 120 x 1 x 1 maps into N x 120 feature vectors.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        # No non-linearity on the output: the scores are raw logits.
        x = self.fc2(x)
        return x
142 ######################################################################
#                                        ----------------------
#    input (assumed 128x128, as for the shallow net)        1
# -- conv(7x7 x 32, stride=4, pad=3)     -> 32x32          32
# -- max(2x2)                            -> 16x16          32
# -- conv(5x5 x 96, pad=2)               -> 16x16          96
# -- max(2x2)                            -> 8x8            96
# -- conv(3x3 x 128, pad=1)              -> 8x8           128
# -- conv(3x3 x 128, pad=1)              -> 8x8           128
# -- conv(3x3 x 96, pad=1)               -> 8x8            96
# -- max(2x2)                            -> 4x4            96
# -- reshape                             -> 1536            1
# -- full(1536x256)                      -> 256             1
# -- full(256x256)                       -> 256             1
# -- full(256x2)                         -> 2               1
class AfrozeDeepNet(nn.Module):
    """Five-convolution network followed by three fully connected layers.

    The first convolution has stride 4 and there are three 2x2 max
    poolings, so the spatial size is divided by 32 overall. The first
    fully connected layer expects 1536 = 96 x 4 x 4 features, which is
    what a single-channel 128x128 input produces.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used as the prefix of the file the parameters are saved to.
        self.name = 'deepnet'

    def forward(self, x):
        """Return the N x 2 class scores for a batch `x` of N images."""
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # Flatten the N x 96 x 4 x 4 maps into N x 1536 feature vectors.
        x = x.view(-1, 1536)

        x = self.fc1(x)
        x = fn.relu(x)

        x = self.fc2(x)
        x = fn.relu(x)

        # No non-linearity on the output: the scores are raw logits.
        x = self.fc3(x)

        return x
206 ######################################################################
def train_model(model, train_set):
    """Train `model` in place with SGD on `train_set` and return it.

    Runs args.nb_epochs passes over the train_set.nb_batches batches
    returned by train_set.get_batch(b), minimising the cross-entropy
    loss with a learning rate of 1e-2. After each epoch the summed batch
    loss is logged, with an estimated completion time as console remark.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            batch_input, target = train_set.get_batch(b)
            output = model.forward(Variable(batch_input))
            loss = criterion(output, Variable(target))
            # Recent PyTorch returns a 0-dim loss that cannot be indexed
            # with [0]; item() is the supported accessor there. Older
            # releases without item() keep the original access.
            if hasattr(loss, 'item'):
                acc_loss = acc_loss + loss.item()
            else:
                acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Average duration of the epochs done so far, extrapolated to
        # the epochs that are still to run (none after the last one).
        dt = (time.time() - start_t) / (e + 1)
        nb_remaining_epochs = args.nb_epochs - (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * nb_remaining_epochs) + ']')

    return model
235 ######################################################################
def nb_errors(model, data_set):
    """Return the number of samples of `data_set` that `model` misclassifies.

    The predicted class of a sample is the index of its largest output
    score (winner-take-all). Batches are obtained with
    data_set.get_batch(b) for b in 0 .. data_set.nb_batches - 1.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        batch_input, target = data_set.get_batch(b)
        output = model.forward(Variable(batch_input))
        wta_prediction = output.data.max(1)[1].view(-1)
        # Count the mismatches of the whole batch in one tensor
        # operation instead of comparing the samples one by one.
        ne += int((wta_prediction != target.view(-1)).sum())

    return ne
250 ######################################################################
# Record every command-line setting in the log, one entry per argument.
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
255 ######################################################################
def int_to_suffix(n):
    """Return a compact string for `n`, used in the saved model file names.

    Multiples of one million above one million become '<k>M', multiples
    of one thousand above one thousand become '<k>K', and any other
    value is returned as its plain decimal string.

    The comparisons are strict, so exactly 1000000 yields '1000K' and
    exactly 1000 yields '1000'. This is kept as is because the result is
    part of the names of already saved model files.
    """
    if n > 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n > 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        # Without this fallback the function would return None and the
        # file name concatenation in the caller would fail.
        return str(n)
265 ######################################################################
# Both sets are processed in full batches only, so the sample counts
# must be exact multiples of the batch size. Abort before any data is
# generated instead of merely printing a warning.
if args.batch_size <= 0:
    raise ValueError('The batch size must be positive.')

if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    raise ValueError('The number of samples must be a multiple of the batch size.')
# Pick the data set class once; both are called with the same
# arguments further down.
if args.compress_vignettes:
    VignetteSet = vignette_set.CompressedVignetteSet
else:
    VignetteSet = vignette_set.VignetteSet
# For each problem number from 1 to 23: build a model, reuse its saved
# parameters if they can be loaded, otherwise train and save it, then
# report the test error.
for problem_number in range(1, 24):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    model_filename = model.name + '_' + \
                     str(problem_number) + '_' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except Exception:
        # Deliberately best-effort: a missing, unreadable or
        # incompatible parameter file simply means the model has to be
        # trained. Exception (rather than a bare except) lets
        # KeyboardInterrupt and SystemExit through.
        need_to_train = True

    ##################################################
    # Train if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        # 100.0 forces a floating point percentage whatever the
        # interpreter's integer division rule.
        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if need_to_train or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))
        )

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
359 ######################################################################