3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from colorama import Fore, Back, Style
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
46 ######################################################################
# Command-line configuration. distutils.util.strtobool lets the boolean
# flags accept 'yes'/'no'/'true'/'false'/'0'/'1' on the command line.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
82 ######################################################################
log_file = open(args.log_file, 'w')

# Time stamp of the previous log_string call, used to report the time
# elapsed between consecutive log lines ('NA' on the first call).
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Logs and prints the string, with a time stamp. Does not log the
# remark.

def log_string(s, remark = ''):
    """Write s to the log file and echo it on stdout, prefixed with a
    time stamp and the delay since the previous call. The remark is
    printed (in cyan) but not written to the log file."""

    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'NA'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
    # Flush so the log is usable while a long run is still in progress
    log_file.flush()

    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
108 ######################################################################
110 # Afroze's ShallowNet
113 # ----------------------
115 # -- conv(21x21 x 6) -> 108x108 6
116 # -- max(2x2) -> 54x54 6
117 # -- conv(19x19 x 16) -> 36x36 16
118 # -- max(2x2) -> 18x18 16
119 # -- conv(18x18 x 120) -> 1x1 120
120 # -- reshape -> 120 1
121 # -- full(120x84) -> 84 1
122 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's ShallowNet: three convolutional layers followed by two
    fully connected layers, for 1x128x128 input images, 2 classes."""

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used to build the filename the parameters are saved under
        self.name = 'shallownet'

    def forward(self, x):
        # 1x128x128 -> 6x108x108 -> 6x54x54
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        # -> 16x36x36 -> 16x18x18
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        # -> 120x1x1
        x = fn.relu(self.conv3(x))
        # Flatten to a 120-dim feature vector per sample
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
143 ######################################################################
class AfrozeDeepNet(nn.Module):
    """Afroze's AlexNet-like deep model: five convolutional layers and
    three fully connected layers, for 1x128x128 inputs, 2 classes."""

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        # 96 maps of 4x4 after the last pooling -> 1536 features
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used to build the filename the parameters are saved under
        self.name = 'deepnet'

    def forward(self, x):
        # 1x128x128 -> 32x32x32 (stride 4) -> 32x16x16
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # -> 96x16x16 -> 96x8x8
        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # -> 128x8x8
        x = self.conv3(x)
        x = fn.relu(x)

        # -> 128x8x8
        x = self.conv4(x)
        x = fn.relu(x)

        # -> 96x8x8 -> 96x4x4
        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # Flatten to 96 * 4 * 4 = 1536 features per sample
        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
191 ######################################################################
def train_model(model, train_set):
    """Train model on train_set with plain SGD and cross-entropy loss
    for args.nb_epochs epochs, logging the accumulated loss and an ETA
    after each epoch. Returns the trained model."""

    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): loss.data[0] is the pre-0.4 pytorch idiom,
            # kept for consistency with the rest of the file; on modern
            # pytorch this should be loss.item()
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        # Average epoch duration so far, used for the ETA estimate
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

    return model
220 ######################################################################
def nb_errors(model, data_set):
    """Return the number of samples of data_set that model
    mis-classifies, with winner-take-all (argmax) prediction."""

    ne = 0

    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all: index of the strongest output per sample
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
235 ######################################################################
# Record the full configuration in the log, one line per argument,
# so every run is reproducible from its log file.
for name in vars(args):
    log_string('argument ' + str(name) + ' ' + str(getattr(args, name)))
240 ######################################################################
def int_to_suffix(n):
    """Return n as a compact string: exact multiples of 1e6 are written
    with an 'M' suffix, exact multiples of 1e3 with a 'K' suffix, and
    anything else as plain str(n).

    The visible code had no fallback branch and returned None for
    non-round values; the final else restores a sensible result."""
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
class vignette_logger():
    """Callable progress logger for vignette generation: called as
    logger(n, m) with n the total number of samples and m the number
    generated so far, it logs progress and an ETA at most once every
    delay_min seconds."""

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        # Throttle on the time of the last message. The visible code
        # compared against start_t, which keeps the condition true
        # forever once the initial delay has passed, i.e. it would log
        # on every single call after the first minute.
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
            self.last_t = t
265 ######################################################################
# nb_batches is computed by integer division in the vignette sets, so
# sample counts that are not multiples of the batch size would silently
# drop samples. Abort instead of continuing with a bad configuration
# (the visible code printed the message but did not stop).
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    raise SystemExit(1)

# Select the vignette container class once; both expose the same
# interface (nb_batches, get_batch, nb_samples, batch_size).
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
# Main loop: for each of the 23 SVRT problems, build (or load) a model,
# train it if needed, and report train/test error rates.
for problem_number in range(1, 24):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    # Parameter file name encodes model, problem and training-set size
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False

    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except Exception:
        # No usable saved parameters for this configuration: train anew
        need_to_train = True

    ##################################################
    # Train if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if need_to_train or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))
        )

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
362 ######################################################################