# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import time

import torch
from colorama import Fore, Back, Style
from torch import FloatTensor as Tensor
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line options. ArgumentDefaultsHelpFormatter makes --help
# display the default value of every option.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# The two options below are numbers of mini-batches, not of samples:
# they are passed as the batch count when building the vignette sets.
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'default.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    action='store_true', default = False,
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    action='store_true', default = False,
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
86 ######################################################################
# The log file stays open for the whole run; log_string() flushes it
# after every line.
log_file = open(args.log_file, 'w')

# Time of the previous call to log_string(), None until the first call
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Log and prints the string, with a time stamp. Does not log the
# remark
#
#   s      -- message written to the log file and printed
#   remark -- extra text appended to the printed line only
#
# Every line is prefixed with the current date and the time elapsed
# since the previous call.
def log_string(s, remark = ''):
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        # First call: there is no previous time stamp to compare to
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    s = Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    # Flush right away so that the log is complete even if the run is
    # interrupted
    log_file.flush()

    print(s + Fore.CYAN + remark + Style.RESET_ALL)
112 ######################################################################
# Afroze's ShallowNet
#
#                       map size   nb. maps
#                     ----------------------
#    input               128x128    1
# -- conv(21x21 x 6)   -> 108x108   6
# -- max(2x2)          -> 54x54     6
# -- conv(19x19 x 16)  -> 36x36     16
# -- max(2x2)          -> 18x18     16
# -- conv(18x18 x 120) -> 1x1       120
# -- reshape           -> 120       1
# -- full(120x84)      -> 84        1
# -- full(84x2)        -> 2         1

class AfrozeShallowNet(nn.Module):
    # Small LeNet-like convnet producing two class scores per sample.
    # The kernel sizes are such that a 128x128 single-channel input is
    # reduced to a 1x1 map by conv3.

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used to build the file name of the saved parameters
        self.name = 'shallownet'

    # x is a batch of single-channel images; returns one row of two
    # un-normalized class scores per image.
    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten while keeping the batch dimension explicit, so that
        # an unexpected input size fails in fc1 instead of silently
        # changing the number of rows
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
147 ######################################################################
# Afroze's DeepNet
#
#                                      map size   nb. maps
#                                    ----------------------
#    input                              128x128    1
# -- conv(7x7 x 32 stride=4 pad=3)   -> 32x32      32
# -- max(2x2)                        -> 16x16      32
# -- conv(5x5 x 96 pad=2)            -> 16x16      96
# -- max(2x2)                        -> 8x8        96
# -- conv(3x3 x 128 pad=1)           -> 8x8        128
# -- conv(3x3 x 128 pad=1)           -> 8x8        128
# -- conv(3x3 x 96 pad=1)            -> 8x8        96
# -- max(2x2)                        -> 4x4        96
# -- reshape                         -> 1536       1
# -- full(1536x256)                  -> 256        1
# -- full(256x256)                   -> 256        1
# -- full(256x2)                     -> 2          1

class AfrozeDeepNet(nn.Module):
    # Alexnet-like convnet producing two class scores per sample. fc1
    # expects 1536 = 96 x 4 x 4 features, which is what a 128x128
    # single-channel input yields after the last max-pooling.

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used to build the file name of the saved parameters
        self.name = 'deepnet'

    # x is a batch of single-channel images; returns one row of two
    # un-normalized class scores per image.
    def forward(self, x):
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # Flatten while keeping the batch dimension explicit, so that
        # an unexpected input size fails in fc1 instead of silently
        # changing the number of rows
        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        x = fn.relu(x)

        x = self.fc2(x)
        x = fn.relu(x)

        x = self.fc3(x)

        return x
211 ######################################################################
# Trains the model in place with SGD (lr = 1e-2) and a cross-entropy
# loss, for args.nb_epochs passes over the batches of train_set.
#
#   model     -- the nn.Module to train
#   train_set -- provides nb_batches and get_batch(b), the latter
#                returning an (input, target) pair of tensors
#
# Logs the loss summed over the batches after every epoch, and returns
# the model.
def train_model(model, train_set):
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            inputs, targets = train_set.get_batch(b)
            output = model.forward(Variable(inputs))
            loss = criterion(output, Variable(targets))
            # From PyTorch 0.4 on the loss is a 0-dim tensor, which
            # cannot be indexed and must be read with item(); older
            # versions only support the indexing form
            if hasattr(loss, 'item'):
                acc_loss = acc_loss + loss.item()
            else:
                acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        # Average duration of the e + 1 epochs done so far
        dt = (time.time() - start_t) / (e + 1)
        # The ETA is only printed, not logged. There are
        # nb_epochs - (e + 1) epochs left at this point.
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - (e + 1))) + ']')

    return model
240 ######################################################################
# Counts the samples of data_set that the model misclassifies.
#
#   model    -- the nn.Module to evaluate
#   data_set -- provides nb_batches and get_batch(b), the latter
#               returning an (input, target) pair of tensors, target
#               holding one class index per sample
#
# The predicted class of a sample is the index of its largest output
# (winner-take-all). Returns the number of errors as an int.
def nb_errors(model, data_set):
    ne = 0
    for b in range(0, data_set.nb_batches):
        inputs, targets = data_set.get_batch(b)
        output = model.forward(Variable(inputs))
        wta_prediction = output.data.max(1)[1].view(-1)
        # Compare the whole batch at once instead of looping over the
        # samples in Python
        ne += int((wta_prediction != targets).sum())

    return ne
255 ######################################################################
# Record every command-line option and its value at the top of the log
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
260 ######################################################################
# For each of the SVRT problems 1 to 23: build a model, load its
# parameters from disk if a matching file exists and train it
# otherwise, then log the number of parameters and the train / test
# error rates.

use_cuda = torch.cuda.is_available()

# The same vignette set class is used for all the train and test sets
if args.compress_vignettes:
    vignette_set_class = CompressedVignetteSet
else:
    vignette_set_class = VignetteSet

for problem_number in range(1, 24):

    log_string('**** problem ' + str(problem_number) + ' ****')

    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if use_cuda:
        model.cuda()

    # One parameter file per (model, problem, training set size)
    model_filename = model.name + '_' + \
                     str(problem_number) + '_' + \
                     str(args.nb_train_batches) + '.param'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    # Any failure to load saved parameters (missing or unreadable
    # file, mismatching architecture) means the model has to be
    # trained. Exception does not cover KeyboardInterrupt, so Ctrl-C
    # still stops the run.
    try:
        model.load_state_dict(torch.load(model_filename))
    except Exception:
        need_to_train = True
    else:
        need_to_train = False
        log_string('loaded_model ' + model_filename)

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = vignette_set_class(problem_number,
                                       args.nb_train_batches, args.batch_size,
                                       cuda=use_cuda)

        log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        # 100.0 forces a float division whatever the Python version
        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    # The test error of a loaded model is computed only on request
    if need_to_train or args.test_loaded_models:

        t = time.time()

        test_set = vignette_set_class(problem_number,
                                      args.nb_test_batches, args.batch_size,
                                      cuda=use_cuda)

        log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
343 ######################################################################