3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import distutils.util
import time

import torch
import torch.nn as nn

from colorama import Fore, Back, Style

from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

import svrtset
46 ######################################################################
# Command-line interface.  All options have defaults so the script can be
# launched bare; boolean flags go through distutils.util.strtobool so that
# '--flag True' / '--flag False' both parse.
# NOTE(review): distutils is removed in Python 3.12 -- strtobool would need
# a local replacement there.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
82 ######################################################################
# Open the run-wide log file; every log_string() call appends to it.
# NOTE(review): mode 'w' truncates -- rerunning with the same --log_file
# overwrites the previous run's log.
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Elapsed-time state shared across log_string() calls; None until the
# first call.
pred_log_t = None

# Log and prints the string, with a time stamp. Does not log the
# remark.

def log_string(s, remark = ''):
    """Write s to the log file and echo it to stdout.

    Both outputs carry a wall-clock time stamp and the delay since the
    previous call ('start!' on the first call).  The remark is printed
    to stdout only, never written to the log file.
    """
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start!'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
    # Flush so the log stays usable even if a long training run is killed.
    log_file.flush()

    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
108 ######################################################################
110 # Afroze's ShallowNet
113 # ----------------------
115 # -- conv(21x21 x 6) -> 108x108 6
116 # -- max(2x2) -> 54x54 6
117 # -- conv(19x19 x 16) -> 36x36 16
118 # -- max(2x2) -> 18x18 16
119 # -- conv(18x18 x 120) -> 1x1 120
120 # -- reshape -> 120 1
121 # -- full(120x84) -> 84 1
122 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """LeNet-style shallow classifier for 1x128x128 SVRT vignettes.

    Per the architecture table above: 128 -conv21-> 108 -pool-> 54
    -conv19-> 36 -pool-> 18 -conv18-> 1x1x120 -> fc 120-84 -> fc 84-2.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Tag used by the main loop to build the checkpoint filename.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # conv3 output is Bx120x1x1; flatten to Bx120 for the linear layers.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
143 ######################################################################
class AfrozeDeepNet(nn.Module):
    """AlexNet-like deep classifier for 1x128x128 SVRT vignettes.

    Spatial sizes: 128 -conv1(s4)-> 32 -pool-> 16 -conv2-> 16 -pool-> 8
    -conv3/4/5-> 8 -pool-> 4, so the flattened feature is 96*4*4 = 1536.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Tag used by the main loop to build the checkpoint filename.
        self.name = 'deepnet'

    def forward(self, x):
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # Flatten Bx96x4x4 to Bx1536 for the classifier head.
        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
191 ######################################################################
def train_model(model, train_set):
    """Train model on train_set with plain SGD and cross-entropy loss.

    Runs args.nb_epochs epochs, logging the accumulated loss and an ETA
    after each one.  Returns the trained model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): loss.data[0] is the pytorch <= 0.3 idiom the
            # rest of this file uses; on >= 0.4 this should be loss.item().
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        # Average wall-clock time per epoch so far, used for the ETA.
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

    return model
220 ######################################################################
def nb_errors(model, data_set):
    """Return the number of misclassified samples of model over data_set.

    data_set must provide nb_batches, batch_size and get_batch(b)
    returning an (input, target) pair of tensors.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all: predicted class is the argmax over output units.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
235 ######################################################################
# Record every command-line argument (name and value) in the log so a run
# is reproducible from its log file alone.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
240 ######################################################################
def int_to_suffix(n):
    """Render n compactly for filenames: exact millions as '<k>M', exact
    thousands as '<k>K', anything else as the plain decimal string."""
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    elif n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    else:
        # Without this branch the function silently returned None for
        # non-round counts, producing 'None' in the checkpoint filename.
        return str(n)
class vignette_logger():
    """Callable progress logger passed to the vignette generator.

    Once delay_min seconds have elapsed since construction, each call
    logs the generation progress together with an ETA extrapolated from
    the average per-sample time so far.
    """
    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.delay_min = delay_min

    def __call__(self, n, m):
        # NOTE(review): n appears to be the total sample count and m the
        # number generated so far, inferred from the ETA formula --
        # confirm against the svrtset caller.
        t = time.time()
        if t > self.start_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
264 ######################################################################
# Each epoch iterates over whole batches only, so both sample counts must
# be exact multiples of the batch size.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # Abort: continuing would silently drop the remainder samples.
    raise ValueError('The number of samples must be a multiple of the batch size.')
# Pick the vignette container class; compressed trades CPU for a much
# smaller memory footprint.  Without the else, the uncompressed branch
# always overwrote the compressed choice.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
# Main loop: one model per SVRT problem (problems are numbered 1..23).
for problem_number in range(1, 24):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    # Both architectures expose the same interface plus a .name tag used
    # to build the checkpoint filename.
    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except (IOError, OSError, RuntimeError):
        # No (or incompatible) checkpoint for this configuration: train
        # from scratch below.
        need_to_train = True

    ##################################################
    # Train if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if need_to_train or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))
        )

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################