3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
43 from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line interface. All data sizes are expressed in mini-batches,
# so the actual number of samples is nb_*_batches * batch_size.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many samples for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many samples for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--test_loaded_models',
                    action='store_true', default = False,
                    help = 'Should we compute the test error of models we load')

args = parser.parse_args()
82 ######################################################################
# Logging: every line goes both to the log file and to stdout, prefixed
# with the wall-clock time and the delay since the previous log line.
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Time of the previous call to log_string; None before the first call.
pred_log_t = None

def log_string(s):
    """Write s to the log file and stdout with a time/elapsed prefix.

    The elapsed field shows the seconds since the previous call
    (NOTE(review): the first-call marker 'start' and the flush/print at
    the end are reconstructed — the original fragment was incomplete).
    """
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    # Flush so a crash mid-run does not lose the tail of the log.
    log_file.flush()
    print(s)
103 ######################################################################
105 # Afroze's ShallowNet
108 # ----------------------
110 # -- conv(21x21 x 6) -> 108x108 6
111 # -- max(2x2) -> 54x54 6
112 # -- conv(19x19 x 16) -> 36x36 16
113 # -- max(2x2) -> 18x18 16
114 # -- conv(18x18 x 120) -> 1x1 120
115 # -- reshape -> 120 1
116 # -- full(120x84) -> 84 1
117 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Shallow convnet for 1x128x128 SVRT vignettes, two-class output.

    Layer sizes (from the architecture comment above):
      conv(21x21, 6)   -> 108x108x6
      max(2x2)         -> 54x54x6
      conv(19x19, 16)  -> 36x36x16
      max(2x2)         -> 18x18x16
      conv(18x18, 120) -> 1x1x120
      full(120 -> 84), full(84 -> 2)
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used to build the saved-parameter file name in the main loop.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # conv3 outputs Nx120x1x1; flatten to Nx120 for the linear layers
        # ("reshape -> 120 1" in the architecture description).
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        # Return raw two-class scores; CrossEntropyLoss applies log-softmax.
        return self.fc2(x)
138 ######################################################################
def train_model(model, train_set):
    """Train model on train_set with SGD and cross-entropy loss.

    Runs args.nb_epochs epochs over all batches of train_set and logs
    the accumulated loss after each epoch. Returns the trained model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # NOTE(review): reconstructed — the visible fragment had an empty
        # branch here; moving the criterion to the GPU matches the
        # cuda=torch.cuda.is_available() data sets used by the caller.
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss.data[0]: pre-0.4 PyTorch scalar access, consistent with
            # the Variable API used throughout this file.
            acc_loss = acc_loss + loss.data[0]
            # The fragment created the optimizer but never stepped it:
            # without zero_grad/backward/step no training happens at all.
            model.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))

    return model
163 ######################################################################
def nb_errors(model, data_set):
    """Return the number of classification errors of model on data_set.

    Predictions are winner-take-all: the class with the maximum output
    score. The visible fragment never initialized, incremented, or
    returned the counter; fixed here.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # max(1) over the class dimension; [1] is the argmax index tensor.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
178 ######################################################################
# Record every command-line argument in the log, so a run can be
# reproduced from its log file alone.
for name, value in vars(args).items():
    log_string('argument ' + str(name) + ' ' + str(value))
183 ######################################################################
# Main driver: for each of the 23 SVRT problems, build a ShallowNet,
# reload a previously saved parameter file if one exists, otherwise
# generate vignettes and train, then measure train/test error.
# NOTE(review): several lines of this loop were missing from the
# fragment (cuda transfer, need_to_train handling, else branches, the
# first/last arguments of the error log_string calls); they are
# reconstructed from the surrounding code and marked where uncertain.
for problem_number in range(1, 24):

    log_string('**** problem ' + str(problem_number) + ' ****')

    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    # Parameter file name encodes the architecture, the problem, and the
    # amount of training data.
    model_filename = model.name + '_' + \
        str(problem_number) + '_' + \
        str(args.nb_train_batches) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    need_to_train = False
    try:
        # Reuse a previously trained model when its parameter file exists.
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except IOError:
        # No saved parameters for this configuration: train from scratch.
        need_to_train = True

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        if args.compress_vignettes:
            train_set = CompressedVignetteSet(problem_number,
                                             args.nb_train_batches, args.batch_size,
                                             cuda=torch.cuda.is_available())
        else:
            train_set = VignetteSet(problem_number,
                                    args.nb_train_batches, args.batch_size,
                                    cuda=torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    # Test error is computed for freshly trained models and, on request,
    # for loaded ones too (--test_loaded_models).
    if need_to_train or args.test_loaded_models:

        t = time.time()

        if args.compress_vignettes:
            test_set = CompressedVignetteSet(problem_number,
                                            args.nb_test_batches, args.batch_size,
                                            cuda=torch.cuda.is_available())
        else:
            test_set = VignetteSet(problem_number,
                                   args.nb_test_batches, args.batch_size,
                                   cuda=torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
263 ######################################################################