3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from colorama import Fore, Back, Style
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
46 ######################################################################
# Command-line interface for the SVRT experiment driver.
# NOTE(review): the closing ')' of the ArgumentParser(...) call is not
# visible in this excerpt of the file.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter

# Number of training vignettes generated per problem.
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

# Number of test vignettes generated per problem.
parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

# Training epochs per model.
parser.add_argument('--nb_epochs',
                    type = int, default = 50)

# Mini-batch size; the sample counts must be multiples of it (checked
# before the main loop below).
parser.add_argument('--batch_size',
                    type = int, default = 100)

# File that log_string() appends to, in addition to stdout.
parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# strtobool lets the user write yes/no/true/false/1/0 on the command line.
parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers to process.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
86 ######################################################################
# Append mode, so successive runs accumulate in the same log file.
log_file = open(args.log_file, 'a')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Log and prints the string, with a time stamp. Does not log the
# NOTE(review): this excerpt is missing intervening lines — in
# particular the line reading the clock into ``t``, the branch that
# assigns ``elapsed`` on the first call (when ``pred_log_t`` is None),
# the ``else:``, and the update of the global ``pred_log_t``.
def log_string(s, remark = ''):

    if pred_log_t is None:

    # Elapsed wall-clock time since the previous log_string() call.
    elapsed = '+{:.02f}s'.format(t - pred_log_t)

    # Plain-text record (no ANSI colors) goes to the log file.
    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')

    # Colorized copy to stdout; per the header comment, the remark is
    # printed but not written to the log file.
    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
112 ######################################################################
114 # Afroze's ShallowNet
117 # ----------------------
119 # -- conv(21x21 x 6) -> 108x108 6
120 # -- max(2x2) -> 54x54 6
121 # -- conv(19x19 x 16) -> 36x36 16
122 # -- max(2x2) -> 18x18 16
123 # -- conv(18x18 x 120) -> 1x1 120
124 # -- reshape -> 120 1
125 # -- full(120x84) -> 84 1
126 # -- full(84x2) -> 2 1
# Shallow LeNet-style classifier for single-channel SVRT vignettes with
# a two-class output; layer geometry follows the diagram above.
class AfrozeShallowNet(nn.Module):
    # NOTE(review): the ``def __init__(self):`` header line is not
    # visible in this excerpt; the statements below are its body.
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)     # -> 6 x 108x108 per diagram
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)  # collapses to 1x1 spatially
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)                      # two-class output
        self.name = 'shallownet'                         # used to build model filenames

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # NOTE(review): the reshape to (N, 120) that the diagram above
        # documents is not visible in this excerpt.
        x = fn.relu(self.fc1(x))
        # NOTE(review): the final fc2 application and the return
        # statement are not visible in this excerpt.
147 ######################################################################
# Afroze's deeper AlexNet-like classifier, also with a two-class output.
class AfrozeDeepNet(nn.Module):
    # NOTE(review): the ``def __init__(self):`` header line is not
    # visible in this excerpt; the statements below are its body.
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        # 1536 presumably = 96 channels x 4x4 spatial after the three
        # poolings below — TODO confirm against the elided forward()
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.name = 'deepnet'  # used to build model filenames

    def forward(self, x):
        # NOTE(review): the convolution/ReLU applications interleaved
        # with these poolings, the flattening, the fully-connected
        # layers and the return statement are not visible in this
        # excerpt.
        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)

        x = fn.max_pool2d(x, kernel_size=2)
195 ######################################################################
# Trains ``model`` in place on ``train_set`` with plain SGD, logging
# the accumulated batch loss and an ETA after each epoch.
# NOTE(review): several lines are elided from this excerpt — the body
# of the CUDA branch (presumably moving the criterion to the GPU), the
# per-epoch reset of ``acc_loss``, the optimizer.zero_grad() /
# loss.backward() / optimizer.step() triple, and any return statement.
def train_model(model, train_set):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):

        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss.data[0]: pre-0.4 PyTorch idiom to read a scalar loss
            acc_loss = acc_loss + loss.data[0]

        # Average wall-clock time per epoch so far, used for the ETA
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
224 ######################################################################
# Counts the samples of ``data_set`` misclassified by ``model``.
# NOTE(review): the initialization of the error counter, its increment
# inside the ``if``, and the return statement are not visible in this
# excerpt.
def nb_errors(model, data_set):

    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all: index of the max output is the predicted class
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
239 ######################################################################
# Record every command-line setting in the log, for reproducibility.
for name, value in vars(args).items():
    log_string('argument ' + str(name) + ' ' + str(value))
244 ######################################################################
def int_to_suffix(n):
    """Render the sample count ``n`` compactly for use in filenames.

    Exact multiples of a million become '<q>M', exact multiples of a
    thousand become '<q>K', and anything else is returned as plain
    digits.  Always returns a str: as visible, the original fell
    through for other values and implicitly returned None, which would
    make the ``model_filename`` string concatenation below raise a
    TypeError.
    """
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        # Fallback for counts that are not round thousands/millions.
        return str(n)
# Rate-limited progress reporter handed as ``logger`` to the vignette
# set constructor; invoked as logger(n, m) during sample generation.
class vignette_logger():
    def __init__(self, delay_min = 60):
        # Timestamps are time.time() values, so delay_min is a minimum
        # interval in seconds between two reports, despite the name.
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        # NOTE(review): lines reading the current time into ``t``,
        # updating ``self.last_t``, and part of the format arguments /
        # closing parentheses are not visible in this excerpt.
        if t > self.last_t + self.delay_min:
            # Average time per generated sample so far; presumably m is
            # the number generated and n the total — TODO confirm
            # against the VignetteSet caller.
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
270 ######################################################################
# The vignette sets are handled batch-wise, so both sample counts must
# split evenly into batches.
# NOTE(review): the line aborting the run after this message is not
# visible in this excerpt.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
log_string('############### start ###############')

# Pick the vignette container class once; both variants are used
# interchangeably below (nb_batches, get_batch, nb_samples).
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
# NOTE(review): the ``else:`` line introducing the next two statements
# is not visible in this excerpt.
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
# Main driver: for each requested SVRT problem, build (or load from
# disk) a model, train it if needed, and log train/test error rates.
# NOTE(review): this excerpt elides many structural lines — the
# if/else selecting the model class, the try/except around the model
# load (presumably setting need_to_train on failure), the
# ``if need_to_train:`` guard, ``nb_parameters = 0``, the timer
# assignments ``t = time.time()``, and several format arguments and
# closing parentheses.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

        model = AfrozeDeepNet()

        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    # Filename encodes architecture, problem number and training-set size
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False

    model.load_state_dict(torch.load(model_filename))
    log_string('loaded_model ' + model_filename)

    ##################################################
    # Train if no saved model could be loaded

    log_string('training_model ' + model_filename)

    train_set = VignetteSet(problem_number,
                            args.nb_train_samples, args.batch_size,
                            cuda = torch.cuda.is_available(),
                            logger = vignette_logger())

    log_string('data_generation {:0.2f} samples / s'.format(
        train_set.nb_samples / (time.time() - t))

    train_model(model, train_set)
    torch.save(model.state_dict(), model_filename)
    log_string('saved_model ' + model_filename)

    nb_train_errors = nb_errors(model, train_set)

    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        100 * nb_train_errors / train_set.nb_samples,
        train_set.nb_samples)

    ##################################################
    # Test the model, either freshly trained or loaded on request

    if need_to_train or args.test_loaded_models:

        # No logger here: test-set generation is comparatively short
        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
369 ######################################################################