3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with pysvrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import distutils.util
import time

from colorama import Fore, Back, Style

import torch

from torch import optim
from torch import nn
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

# Project-local module providing the (compressed) vignette data sets.
import vignette_set
46 ######################################################################
48 parser = argparse.ArgumentParser(
49 description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50 formatter_class = argparse.ArgumentDefaultsHelpFormatter
53 parser.add_argument('--nb_train_samples',
54 type = int, default = 100000)
56 parser.add_argument('--nb_test_samples',
57 type = int, default = 10000)
59 parser.add_argument('--nb_epochs',
60 type = int, default = 50)
62 parser.add_argument('--batch_size',
63 type = int, default = 100)
65 parser.add_argument('--log_file',
66 type = str, default = 'default.log')
68 parser.add_argument('--compress_vignettes',
69 type = distutils.util.strtobool, default = 'True',
70 help = 'Use lossless compression to reduce the memory footprint')
72 parser.add_argument('--deep_model',
73 type = distutils.util.strtobool, default = 'True',
74 help = 'Use Afroze\'s Alexnet-like deep model')
76 parser.add_argument('--test_loaded_models',
77 type = distutils.util.strtobool, default = 'False',
78 help = 'Should we compute the test errors of loaded models')
80 args = parser.parse_args()
82 ######################################################################
84 log_file = open(args.log_file, 'w')
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Log and prints the string, with a time stamp. Does not log the remark.
91 def log_string(s, remark = ''):
96 if pred_log_t is None:
99 elapsed = '+{:.02f}s'.format(t - pred_log_t)
103 log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
106 print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
108 ######################################################################
110 # Afroze's ShallowNet
113 # ----------------------
115 # -- conv(21x21 x 6) -> 108x108 6
116 # -- max(2x2) -> 54x54 6
117 # -- conv(19x19 x 16) -> 36x36 16
118 # -- max(2x2) -> 18x18 16
119 # -- conv(18x18 x 120) -> 1x1 120
120 # -- reshape -> 120 1
121 # -- full(120x84) -> 84 1
122 # -- full(84x2) -> 2 1
124 class AfrozeShallowNet(nn.Module):
126 super(AfrozeShallowNet, self).__init__()
127 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
128 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
129 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
130 self.fc1 = nn.Linear(120, 84)
131 self.fc2 = nn.Linear(84, 2)
132 self.name = 'shallownet'
134 def forward(self, x):
135 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
136 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
137 x = fn.relu(self.conv3(x))
139 x = fn.relu(self.fc1(x))
143 ######################################################################
147 class AfrozeDeepNet(nn.Module):
149 super(AfrozeDeepNet, self).__init__()
150 self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
151 self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2)
152 self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
153 self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
154 self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
155 self.fc1 = nn.Linear(1536, 256)
156 self.fc2 = nn.Linear(256, 256)
157 self.fc3 = nn.Linear(256, 2)
158 self.name = 'deepnet'
160 def forward(self, x):
162 x = fn.max_pool2d(x, kernel_size=2)
166 x = fn.max_pool2d(x, kernel_size=2)
176 x = fn.max_pool2d(x, kernel_size=2)
191 ######################################################################
193 def train_model(model, train_set):
194 batch_size = args.batch_size
195 criterion = nn.CrossEntropyLoss()
197 if torch.cuda.is_available():
200 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
202 start_t = time.time()
204 for e in range(0, args.nb_epochs):
206 for b in range(0, train_set.nb_batches):
207 input, target = train_set.get_batch(b)
208 output = model.forward(Variable(input))
209 loss = criterion(output, Variable(target))
210 acc_loss = acc_loss + loss.data[0]
214 dt = (time.time() - start_t) / (e + 1)
215 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
216 ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
220 ######################################################################
222 def nb_errors(model, data_set):
224 for b in range(0, data_set.nb_batches):
225 input, target = data_set.get_batch(b)
226 output = model.forward(Variable(input))
227 wta_prediction = output.data.max(1)[1].view(-1)
229 for i in range(0, data_set.batch_size):
230 if wta_prediction[i] != target[i]:
235 ######################################################################
237 for arg in vars(args):
238 log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
240 ######################################################################
242 def int_to_suffix(n):
243 if n >= 1000000 and n%1000000 == 0:
244 return str(n//1000000) + 'M'
245 elif n >= 1000 and n%1000 == 0:
246 return str(n//1000) + 'K'
250 ######################################################################
252 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
253 print('The number of samples must be a multiple of the batch size.')
256 if args.compress_vignettes:
257 log_string('using_compressed_vignettes')
258 VignetteSet = vignette_set.CompressedVignetteSet
260 log_string('using_uncompressed_vignettes')
261 VignetteSet = vignette_set.VignetteSet
263 for problem_number in range(1, 24):
265 log_string('############### problem ' + str(problem_number) + ' ###############')
268 model = AfrozeDeepNet()
270 model = AfrozeShallowNet()
272 if torch.cuda.is_available(): model.cuda()
274 model_filename = model.name + '_pb:' + \
275 str(problem_number) + '_ns:' + \
276 int_to_suffix(args.nb_train_samples) + '.param'
279 for p in model.parameters(): nb_parameters += p.numel()
280 log_string('nb_parameters {:d}'.format(nb_parameters))
282 ##################################################
283 # Tries to load the model
285 need_to_train = False
287 model.load_state_dict(torch.load(model_filename))
288 log_string('loaded_model ' + model_filename)
292 ##################################################
297 log_string('training_model ' + model_filename)
301 train_set = VignetteSet(problem_number,
302 args.nb_train_samples, args.batch_size,
303 cuda = torch.cuda.is_available())
305 log_string('data_generation {:0.2f} samples / s'.format(
306 train_set.nb_samples / (time.time() - t))
309 train_model(model, train_set)
310 torch.save(model.state_dict(), model_filename)
311 log_string('saved_model ' + model_filename)
313 nb_train_errors = nb_errors(model, train_set)
315 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
317 100 * nb_train_errors / train_set.nb_samples,
319 train_set.nb_samples)
322 ##################################################
325 if need_to_train or args.test_loaded_models:
329 test_set = VignetteSet(problem_number,
330 args.nb_test_samples, args.batch_size,
331 cuda = torch.cuda.is_available())
333 log_string('data_generation {:0.2f} samples / s'.format(
334 test_set.nb_samples / (time.time() - t))
337 nb_test_errors = nb_errors(model, test_set)
339 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
341 100 * nb_test_errors / test_set.nb_samples,
346 ######################################################################