3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with svrt.  If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
43 from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
def _build_parser():
    """Return the command-line parser of the SVRT convnet test script."""
    cli = argparse.ArgumentParser(
        description='Simple convnet test on the SVRT.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Options carrying a value: (flag, type, default, help).
    valued_options = (
        ('--nb_train_samples', int, 100000, 'How many samples for train'),
        ('--nb_test_samples', int, 10000, 'How many samples for test'),
        ('--nb_epochs', int, 50, 'How many training epochs'),
        ('--batch_size', int, 100, 'Mini-batch size'),
        ('--log_file', str, 'default.log', 'Log file name'),
    )
    for flag, value_type, default_value, help_text in valued_options:
        cli.add_argument(flag, type=value_type, default=default_value,
                         help=help_text)

    # Boolean switches, all off unless given: (flag, help).
    switches = (
        ('--compress_vignettes',
         'Use lossless compression to reduce the memory footprint'),
        ('--deep_model',
         "Use Afroze's Alexnet-like deep model"),
        ('--test_loaded_models',
         'Should we compute the test errors of loaded models'),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action='store_true', default=False,
                         help=help_text)

    return cli


parser = _build_parser()

# Parsed once at start-up; the rest of the script reads its settings
# from this namespace.
args = parser.parse_args()
86 ######################################################################
# The log file stays open for the whole run; log_string flushes after
# every line so the log is usable even if the run is interrupted.
log_file = open(args.log_file, 'w')

# Time of the previous call to log_string, None until the first call.
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)


def log_string(s, remark=''):
    """Log and print the string s, with a time stamp.

    The line is prefixed with the current date and the time elapsed
    since the previous call ('start' for the first call). The colored
    line, terminal escape codes included, is written to the log file
    and printed.

    s      -- the message to log and print.
    remark -- additional text appended to the printed line only; it is
              not written to the log file.
    """
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    # Same clock reading for the date and for the elapsed time.
    s = (Fore.BLUE + '[' + time.ctime(t) + '] '
         + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s)

    log_file.write(s + '\n')
    log_file.flush()

    print(s + Fore.CYAN + remark + Style.RESET_ALL)
112 ######################################################################
114 # Afroze's ShallowNet
117 # ----------------------
119 # -- conv(21x21 x 6) -> 108x108 6
120 # -- max(2x2) -> 54x54 6
121 # -- conv(19x19 x 16) -> 36x36 16
122 # -- max(2x2) -> 18x18 16
123 # -- conv(18x18 x 120) -> 1x1 120
124 # -- reshape -> 120 1
125 # -- full(120x84) -> 84 1
126 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's shallow, LeNet-like two-class convnet.

    Three convolutions (the first two followed by 2x2 max-pooling) and
    two fully connected layers. The kernel sizes are chosen so that a
    1x128x128 input is reduced to a 120x1x1 map by the third
    convolution (128 -> 108 -> 54 -> 36 -> 18 -> 1).

    The `name` attribute is a short identifier of the architecture.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        self.name = 'shallownet'

    def forward(self, x):
        """Return the two class scores for a batch of images.

        x -- batch of single-channel images, expected to be
             N x 1 x 128 x 128 (the size the layers are designed for).

        Returns an N x 2 tensor of un-normalized scores.
        """
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # One row per sample: a wrong input size then fails in fc1
        # instead of silently mixing samples together.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
147 ######################################################################
# Afroze's DeepNet
#
#                                          map size   nb. maps
#                                        ----------------------
#    input                                 128x128       1
# -- conv(7x7 x 32, stride=4, pad=3)  ->    32x32       32
# -- max(2x2)                         ->    16x16       32
# -- conv(5x5 x 96, pad=2)            ->    16x16       96
# -- max(2x2)                         ->     8x8        96
# -- conv(3x3 x 128, pad=1)           ->     8x8       128
# -- conv(3x3 x 128, pad=1)           ->     8x8       128
# -- conv(3x3 x 96, pad=1)            ->     8x8        96
# -- max(2x2)                         ->     4x4        96
# -- reshape                          ->    1536         1
# -- full(1536x256)                   ->     256         1
# -- full(256x256)                    ->     256         1
# -- full(256x2)                      ->       2         1

class AfrozeDeepNet(nn.Module):
    """Afroze's Alexnet-like deep two-class convnet.

    Five convolutions with three 2x2 max-poolings, followed by three
    fully connected layers. The first fully connected layer expects
    1536 = 96 * 4 * 4 features, which is what a 1x128x128 input
    produces.

    The `name` attribute is a short identifier of the architecture.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d(32, 96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.name = 'deepnet'

    def forward(self, x):
        """Return the two class scores for a batch of images.

        x -- batch of single-channel images, expected to be
             N x 1 x 128 x 128 (the size for which fc1 gets its 1536
             input features).

        Returns an N x 2 tensor of un-normalized scores.
        """
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # One row per sample: a wrong input size then fails in fc1
        # instead of silently mixing samples together.
        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        x = fn.relu(x)

        x = self.fc2(x)
        x = fn.relu(x)

        x = self.fc3(x)

        return x
211 ######################################################################
def train_model(model, train_set):
    """Train model in place with plain SGD and a cross-entropy loss.

    Runs args.nb_epochs passes over the batches of train_set, and logs
    after each epoch the loss accumulated over the epoch, with an
    estimated time of completion as console-only remark.

    model     -- the network to train; its parameters are updated.
    train_set -- object providing nb_batches and get_batch(b), the
                 latter returning an (input, target) pair.

    Returns the trained model (the same object that was passed in).
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    start_t = time.time()

    for e in range(args.nb_epochs):
        acc_loss = 0.0

        for b in range(train_set.nb_batches):
            inputs, targets = train_set.get_batch(b)
            output = model.forward(Variable(inputs))
            loss = criterion(output, Variable(targets))

            # Losses are 0-dim tensors from PyTorch 0.4 on, where
            # indexing with [0] is an error; older versions have no
            # item().
            if hasattr(loss, 'item'):
                acc_loss = acc_loss + loss.item()
            else:
                acc_loss = acc_loss + loss.data[0]

            model.zero_grad()
            loss.backward()
            optimizer.step()

        nb_done = e + 1
        dt = (time.time() - start_t) / nb_done
        # Only the epochs still to run count for the ETA, so that it
        # is "now" once the last epoch is over.
        eta = time.time() + dt * (args.nb_epochs - nb_done)
        log_string('train_loss {:d} {:f}'.format(nb_done, acc_loss),
                   ' [ETA ' + time.ctime(eta) + ']')

    return model
240 ######################################################################
def nb_errors(model, data_set):
    """Return the number of samples of data_set misclassified by model.

    The predicted class of a sample is the index of its largest output
    (winner-take-all).

    model    -- object whose forward() maps a batch of inputs to one
                row of class scores per sample.
    data_set -- object providing nb_batches and get_batch(b), the
                latter returning an (input, target) pair with one
                target class index per sample.

    Returns the error count as an int.
    """
    ne = 0

    for b in range(data_set.nb_batches):
        inputs, targets = data_set.get_batch(b)
        output = model.forward(Variable(inputs))
        wta_prediction = output.data.max(1)[1].view(-1)
        # Count the mismatches of the whole batch at once rather than
        # reading the tensors element by element from Python.
        ne += int((wta_prediction != targets.view(-1)).sum())

    return ne
255 ######################################################################
# Record every command-line setting, one 'argument <name> <value>'
# line each, so that a log file describes the run that produced it.
for arg, value in vars(args).items():
    log_string('argument {} {}'.format(arg, value))
260 ######################################################################
def int_to_suffix(n):
    """Return a compact string for n, using a 'K' or 'M' suffix.

    Exact multiples of one million are written with an 'M' suffix,
    other exact multiples of one thousand with a 'K' suffix, and any
    other value (0 included) as the plain number.

    n -- a non-negative integer.

    >>> int_to_suffix(1000000)
    '1M'
    >>> int_to_suffix(100000)
    '100K'
    >>> int_to_suffix(1500)
    '1500'
    """
    # >= and not >, so that exactly 1000 and 1000000 get their suffix.
    if n >= 1000000 and n % 1000000 == 0:
        return str(n // 1000000) + 'M'
    if n >= 1000 and n % 1000 == 0:
        return str(n // 1000) + 'K'
    return str(n)
270 ######################################################################
# Training and evaluation walk through the sets one full batch at a
# time (nb_errors counts batch by batch), so stop right away on sample
# counts that do not split into whole batches instead of only
# reporting the problem and carrying on.
if (args.nb_train_samples % args.batch_size != 0
        or args.nb_test_samples % args.batch_size != 0):
    raise ValueError(
        'The number of samples must be a multiple of the batch size.')
def _generate_vignette_set(problem_number, nb_samples):
    """Generate a vignette set for one problem and log the generation rate.

    The compressed or the plain set class is used according to
    --compress_vignettes; the set is created for CUDA when available.
    """
    if args.compress_vignettes:
        set_class = CompressedVignetteSet
    else:
        set_class = VignetteSet

    t = time.time()
    data_set = set_class(problem_number,
                         nb_samples, args.batch_size,
                         cuda=torch.cuda.is_available())

    log_string('data_generation {:0.2f} samples / s'.format(
        data_set.nb_samples / (time.time() - t)))

    return data_set


# For each of the problems 1 to 23: reuse the parameters saved by a
# previous run when they can be loaded, otherwise train and save the
# model, then report the train and / or test error counts.
for problem_number in range(1, 24):

    log_string('**** problem ' + str(problem_number) + ' ****')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    model_filename = model.name + '_' + \
                     str(problem_number) + '_' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    # Any failure to load (no file yet, file of another architecture,
    # unreadable file) means the model has to be trained. Exception and
    # not a bare except, so that an interruption still stops the script.
    # NOTE(review): torch.load unpickles the file; only parameter files
    # written by this script should be found under that name.
    try:
        model.load_state_dict(torch.load(model_filename))
    except Exception:
        need_to_train = True
    else:
        need_to_train = False
        log_string('loaded_model ' + model_filename)

    if need_to_train:

        log_string('training_model ' + model_filename)

        train_set = _generate_vignette_set(problem_number,
                                           args.nb_train_samples)

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    if need_to_train or args.test_loaded_models:

        test_set = _generate_vignette_set(problem_number,
                                          args.nb_test_samples)

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
361 ######################################################################