3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from colorama import Fore, Back, Style
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
46 ######################################################################
# Command-line interface. Both sample counts must be multiples of
# --batch_size (checked before the main loop). Boolean flags are parsed
# with distutils.util.strtobool so 'True'/'False'/'1'/'0' all work.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)  # NOTE(review): closing parenthesis restored -- missing in the damaged source

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
86 ######################################################################
# Open the log sink once for the whole run; log_string() appends to it.
# 'w' truncates any previous log with the same name.
log_file = open(args.log_file, 'w')

# Announce the log destination on the terminal (colorama color codes;
# RESET_ALL restores the default style).
print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Log and prints the string, with a time stamp. Does not log the remark.
def log_string(s, remark = ''):
    """Write s to the log file and echo it, colorized, to stdout.

    The remark (e.g. an ETA annotation) is printed to the terminal only,
    never written to the log file.

    NOTE(review): this view of the file is elided -- the lines that bind
    the current time t, handle the first-call case of the module-global
    pred_log_t, and update it afterwards are not visible here.
    """
    if pred_log_t is None:
        # NOTE(review): the true-branch body and the 'else:' line are
        # elided; this line computes the time elapsed since the previous
        # log call and presumably belongs to the else branch.
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')

    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
112 ######################################################################
114 # Afroze's ShallowNet
117 # ----------------------
119 # -- conv(21x21 x 6) -> 108x108 6
120 # -- max(2x2) -> 54x54 6
121 # -- conv(19x19 x 16) -> 36x36 16
122 # -- max(2x2) -> 18x18 16
123 # -- conv(18x18 x 120) -> 1x1 120
124 # -- reshape -> 120 1
125 # -- full(120x84) -> 84 1
126 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's ShallowNet: three conv layers and two fully-connected
    layers mapping a 1x128x128 vignette to 2 class scores (see the
    layer-size comment above this class).

    NOTE(review): the damaged source was missing the __init__ header,
    the flatten, the final fc2 call and the return; they are restored
    here consistently with the declared layer sizes.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        self.name = 'shallownet'  # used to build the checkpoint filename

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))  # 1x128x128 -> 6x54x54
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))  # -> 16x18x18
        x = fn.relu(self.conv3(x))                                # -> 120x1x1
        x = x.view(-1, 120)       # flatten to (batch, 120)
        x = fn.relu(self.fc1(x))  # -> (batch, 84)
        x = self.fc2(x)           # -> (batch, 2) raw class scores
        return x
147 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Afroze's AlexNet-like deep model: five conv layers (three of them
    followed by 2x2 max-pooling) and a three-layer fully-connected head,
    mapping a 1x128x128 vignette to 2 class scores.

    NOTE(review): only the three max_pool2d lines of forward() survived
    in the damaged source; the conv/relu/view/fc calls are reconstructed
    so the spatial sizes end at 96 x 4 x 4 = 1536, matching self.fc1's
    declared input width.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.name = 'deepnet'  # used to build the checkpoint filename

    def forward(self, x):
        x = self.conv1(x)                    # 1x128x128 -> 32x32x32
        x = fn.max_pool2d(x, kernel_size=2)  # -> 32x16x16
        x = fn.relu(x)

        x = self.conv2(x)                    # -> 96x16x16
        x = fn.max_pool2d(x, kernel_size=2)  # -> 96x8x8
        x = fn.relu(x)

        x = self.conv3(x)                    # -> 128x8x8
        x = fn.relu(x)

        x = self.conv4(x)                    # -> 128x8x8
        x = fn.relu(x)

        x = self.conv5(x)                    # -> 96x8x8
        x = fn.max_pool2d(x, kernel_size=2)  # -> 96x4x4
        x = fn.relu(x)

        x = x.view(-1, 1536)                 # flatten for the classifier head

        x = self.fc1(x)
        x = fn.relu(x)
        x = self.fc2(x)
        x = fn.relu(x)
        x = self.fc3(x)                      # -> (batch, 2) raw class scores

        return x
195 ######################################################################
def train_model(model, train_set):
    """Train model on train_set with SGD and cross-entropy, logging the
    accumulated loss and an ETA once per epoch.

    NOTE(review): this view of the file is elided -- the cuda branch
    body, the per-epoch reset of acc_loss, the zero_grad()/backward()/
    step() calls and any return statement are not visible here.
    """
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): the branch body (presumably moving the criterion to
    # the GPU) is elided from this view.
    if torch.cuda.is_available():

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss.data[0] is the pre-0.4 PyTorch idiom for reading a
            # scalar loss value (would be loss.item() today).
            acc_loss = acc_loss + loss.data[0]

        # Mean wall-clock seconds per epoch so far, used for the ETA.
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
224 ######################################################################
def nb_errors(model, data_set):
    """Count the samples of data_set that model misclassifies.

    NOTE(review): this view of the file is elided -- the error-counter
    initialization, its increment and the return statement are not
    visible here.
    """
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # max(1) reduces over the class dimension; [1] selects the argmax
        # indices, i.e. the winner-take-all predicted class per sample.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
239 ######################################################################
# Record every command-line argument and its value in the log, so a run
# can be reproduced from its log file alone.
for name, value in vars(args).items():
    log_string('argument ' + str(name) + ' ' + str(value))
244 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string for filenames: exact multiples of
    10^6 get an 'M' suffix, exact multiples of 10^3 a 'K' suffix, and
    any other value is returned as its plain decimal string.

    NOTE(review): the final fall-through branch was missing in the
    damaged source, which made the function return None for values such
    as 500 or 1500; it is restored here.
    """
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    elif n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    else:
        return str(n)
class vignette_logger():
    # Rate-limited progress callback for vignette generation: passed as
    # logger= to VignetteSet, it logs a progress line with an ETA at
    # most once every delay_min seconds.

    def __init__(self, delay_min = 60):
        self.start_t = time.time()    # wall-clock time at construction
        self.last_t = self.start_t    # time of the last emitted message
        self.delay_min = delay_min    # minimum seconds between messages

    def __call__(self, n, m):
        # NOTE(review): this view of the file is elided -- the binding
        # of t (presumably t = time.time()), part of the format
        # arguments, the closing parenthesis and the update of
        # self.last_t are not visible here. From dt below, m appears to
        # be the number of samples generated so far and n the total.
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m   # mean seconds per generated sample
            log_string('sample_generation {:d} / {:d}'.format(
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
270 ######################################################################
# Both loops below walk whole batches only, so the sample counts must
# divide evenly by the batch size.
# NOTE(review): the abort (presumably raise/exit) after the message is
# elided from this view.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')

# Choose the vignette container implementation for the whole run.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
    # NOTE(review): the two lines below are the alternative branch; the
    # 'else:' line itself is elided from this view.
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
# Main driver: for each requested SVRT problem, build (or load) a model,
# train it if no checkpoint was loaded, and log error counts.
# NOTE(review): this view of the file is heavily elided; the hedged
# notes below mark the gaps.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    # NOTE(review): the 'if args.deep_model:' / 'else:' lines selecting
    # between the two constructors are elided from this view.
        model = AfrozeDeepNet()
        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint name encodes architecture, problem number and train-set
    # size, e.g. deepnet_pb:4_ns:100K.param.
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    # NOTE(review): the 'nb_parameters = 0' initialization is elided.
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    # NOTE(review): the try/except around the load -- presumably setting
    # need_to_train = True on failure -- is elided from this view.
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)

    ##################################################
    # Train (and save) the model when no checkpoint could be loaded.
    # NOTE(review): the 'if need_to_train:' guard is elided.

        log_string('training_model ' + model_filename)

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        # NOTE(review): the 't = time.time()' taken before generation is
        # elided; the log reports generation throughput in samples/s.
        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        # NOTE(review): some of this format call's arguments and its
        # closing parenthesis are elided.
        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_train_errors / train_set.nb_samples,
            train_set.nb_samples)

    ##################################################
    # Compute the test error -- for freshly trained models always, for
    # loaded checkpoints only when --test_loaded_models is set.

    if need_to_train or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # NOTE(review): the 't = time.time()' before generation is elided.
        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))

        nb_test_errors = nb_errors(model, test_set)

        # NOTE(review): the remaining format arguments and closing
        # parenthesis are elided from this view.
        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            100 * nb_test_errors / test_set.nb_samples,
367 ######################################################################