# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import distutils.util
import time

from colorama import Fore, Back, Style

import torch

from torch import nn
from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

import svrtset
46 ######################################################################
48 parser = argparse.ArgumentParser(
49 description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50 formatter_class = argparse.ArgumentDefaultsHelpFormatter
53 parser.add_argument('--nb_train_samples',
54 type = int, default = 100000)
56 parser.add_argument('--nb_test_samples',
57 type = int, default = 10000)
59 parser.add_argument('--nb_epochs',
60 type = int, default = 50)
62 parser.add_argument('--batch_size',
63 type = int, default = 100)
65 parser.add_argument('--log_file',
66 type = str, default = 'default.log')
68 parser.add_argument('--compress_vignettes',
69 type = distutils.util.strtobool, default = 'True',
70 help = 'Use lossless compression to reduce the memory footprint')
72 parser.add_argument('--deep_model',
73 type = distutils.util.strtobool, default = 'True',
74 help = 'Use Afroze\'s Alexnet-like deep model')
76 parser.add_argument('--test_loaded_models',
77 type = distutils.util.strtobool, default = 'False',
78 help = 'Should we compute the test errors of loaded models')
80 args = parser.parse_args()
82 ######################################################################
84 log_file = open(args.log_file, 'w')
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the
# remark.
91 def log_string(s, remark = ''):
96 if pred_log_t is None:
99 elapsed = '+{:.02f}s'.format(t - pred_log_t)
103 log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
106 print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
108 ######################################################################
# Afroze's ShallowNet
#
#                       map size   nb. maps
#                     ----------------------
#     input                128x128    1
# -- conv(21x21 x 6)   -> 108x108    6
# -- max(2x2)          -> 54x54      6
# -- conv(19x19 x 16)  -> 36x36     16
# -- max(2x2)          -> 18x18     16
# -- conv(18x18 x 120) -> 1x1      120
# -- reshape           -> 120       1
# -- full(120x84)      -> 84        1
# -- full(84x2)        -> 2         1
124 class AfrozeShallowNet(nn.Module):
126 super(AfrozeShallowNet, self).__init__()
127 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
128 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
129 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
130 self.fc1 = nn.Linear(120, 84)
131 self.fc2 = nn.Linear(84, 2)
132 self.name = 'shallownet'
134 def forward(self, x):
135 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
136 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
137 x = fn.relu(self.conv3(x))
139 x = fn.relu(self.fc1(x))
143 ######################################################################
147 class AfrozeDeepNet(nn.Module):
149 super(AfrozeDeepNet, self).__init__()
150 self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
151 self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
152 self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
153 self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
154 self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
155 self.fc1 = nn.Linear(1536, 256)
156 self.fc2 = nn.Linear(256, 256)
157 self.fc3 = nn.Linear(256, 2)
158 self.name = 'deepnet'
160 def forward(self, x):
162 x = fn.max_pool2d(x, kernel_size=2)
166 x = fn.max_pool2d(x, kernel_size=2)
176 x = fn.max_pool2d(x, kernel_size=2)
191 ######################################################################
193 def train_model(model, train_set):
194 batch_size = args.batch_size
195 criterion = nn.CrossEntropyLoss()
197 if torch.cuda.is_available():
200 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
202 start_t = time.time()
204 for e in range(0, args.nb_epochs):
206 for b in range(0, train_set.nb_batches):
207 input, target = train_set.get_batch(b)
208 output = model.forward(Variable(input))
209 loss = criterion(output, Variable(target))
210 acc_loss = acc_loss + loss.data[0]
214 dt = (time.time() - start_t) / (e + 1)
215 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
216 ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
220 ######################################################################
222 def nb_errors(model, data_set):
224 for b in range(0, data_set.nb_batches):
225 input, target = data_set.get_batch(b)
226 output = model.forward(Variable(input))
227 wta_prediction = output.data.max(1)[1].view(-1)
229 for i in range(0, data_set.batch_size):
230 if wta_prediction[i] != target[i]:
235 ######################################################################
237 for arg in vars(args):
238 log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
240 ######################################################################
242 def int_to_suffix(n):
243 if n >= 1000000 and n%1000000 == 0:
244 return str(n//1000000) + 'M'
245 elif n >= 1000 and n%1000 == 0:
246 return str(n//1000) + 'K'
250 class vignette_logger():
251 def __init__(self, delay_min = 60):
252 self.start_t = time.time()
253 self.last_t = self.start_t
254 self.delay_min = delay_min
256 def __call__(self, n, m):
258 if t > self.last_t + self.delay_min:
259 dt = (t - self.start_t) / m
260 log_string('sample_generation {:d} / {:d}'.format(
262 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
266 ######################################################################
268 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
269 print('The number of samples must be a multiple of the batch size.')
272 if args.compress_vignettes:
273 log_string('using_compressed_vignettes')
274 VignetteSet = svrtset.CompressedVignetteSet
276 log_string('using_uncompressed_vignettes')
277 VignetteSet = svrtset.VignetteSet
279 for problem_number in range(1, 24):
281 log_string('############### problem ' + str(problem_number) + ' ###############')
284 model = AfrozeDeepNet()
286 model = AfrozeShallowNet()
288 if torch.cuda.is_available(): model.cuda()
290 model_filename = model.name + '_pb:' + \
291 str(problem_number) + '_ns:' + \
292 int_to_suffix(args.nb_train_samples) + '.param'
295 for p in model.parameters(): nb_parameters += p.numel()
296 log_string('nb_parameters {:d}'.format(nb_parameters))
298 ##################################################
299 # Tries to load the model
301 need_to_train = False
303 model.load_state_dict(torch.load(model_filename))
304 log_string('loaded_model ' + model_filename)
308 ##################################################
313 log_string('training_model ' + model_filename)
317 train_set = VignetteSet(problem_number,
318 args.nb_train_samples, args.batch_size,
319 cuda = torch.cuda.is_available(),
320 logger = vignette_logger())
322 log_string('data_generation {:0.2f} samples / s'.format(
323 train_set.nb_samples / (time.time() - t))
326 train_model(model, train_set)
327 torch.save(model.state_dict(), model_filename)
328 log_string('saved_model ' + model_filename)
330 nb_train_errors = nb_errors(model, train_set)
332 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
334 100 * nb_train_errors / train_set.nb_samples,
336 train_set.nb_samples)
339 ##################################################
342 if need_to_train or args.test_loaded_models:
346 test_set = VignetteSet(problem_number,
347 args.nb_test_samples, args.batch_size,
348 cuda = torch.cuda.is_available())
350 log_string('data_generation {:0.2f} samples / s'.format(
351 test_set.nb_samples / (time.time() - t))
354 nb_test_errors = nb_errors(model, test_set)
356 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
358 100 * nb_test_errors / test_set.nb_samples,
363 ######################################################################