# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.

import time
import argparse

from colorama import Fore, Back, Style

import torch

from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

from vignette_set import VignetteSet, CompressedVignetteSet
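
# vignette_set is expected to provide VignetteSet and CompressedVignetteSet
# with the interface used below: construction from (problem_number,
# nb_batches, batch_size, cuda=...), the attributes nb_batches, batch_size
# and nb_samples, and get_batch(b) returning an (input, target) tensor pair.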

######################################################################

parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many mini-batches for training')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many mini-batches for testing')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--test_loaded_models',
                    action='store_true', default = False,
                    help = 'Compute the test error of models loaded from disk')

args = parser.parse_args()
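
# Example invocation (illustrative only; assumes the script is saved as
# cnn-svrt.py, which matches the default log file name above):
#
#   python cnn-svrt.py --nb_train_batches 1000 --batch_size 100 --compress_vignettes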

######################################################################

log_file = open(args.log_file, 'w')
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Write s to the log file and stdout, prefixed with the wall-clock time and
# the time elapsed since the previous call
def log_string(s):
    global pred_log_t
    t = time.time()
    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)
    pred_log_t = t
    s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    log_file.flush()
    print(s)

######################################################################

# Afroze's ShallowNet
#
#                          map size   nb. maps
#                        ----------------------
# input                    128x128        1
# -- conv(21x21 x 6)   ->  108x108        6
# -- max(2x2)          ->  54x54          6
# -- conv(19x19 x 16)  ->  36x36         16
# -- max(2x2)          ->  18x18         16
# -- conv(18x18 x 120) ->  1x1          120
# -- reshape           ->  120            1
# -- full(120x84)      ->  84             1
# -- full(84x2)        ->  2              1
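#
# The sizes follow from valid convolutions and 2x2 max-pooling on a 128x128
# input: 128 - 21 + 1 = 108, 108 / 2 = 54, 54 - 19 + 1 = 36, 36 / 2 = 18,
# 18 - 18 + 1 = 1, so conv3 collapses each map to a single value and the
# reshape yields a 120-dimensional vector.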

class AfrozeShallowNet(nn.Module):
    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
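
# Sanity check (illustrative, not part of the original training flow): a
# batch of 1x128x128 vignettes maps to one pair of class scores per vignette,
# e.g.
#
#   net = AfrozeShallowNet()
#   out = net(Variable(Tensor(8, 1, 128, 128).normal_()))   # out.size() == (8, 2)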

######################################################################
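
# Train the model with plain SGD (lr 1e-2) for args.nb_epochs passes over
# train_set, logging the accumulated cross-entropy loss after each epoch.
# Variable and loss.data[0] are the pre-0.4 PyTorch idioms used throughout
# this script (loss.item() is the modern equivalent).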
def train_model(model, train_set):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
        dt = (time.time() - start_t) / (e + 1)
        print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL)

    return model

######################################################################
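
# Count, over all batches of data_set, the samples whose winner-take-all
# prediction (argmax over the two class scores) differs from the target label.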
def nb_errors(model, data_set):
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne

######################################################################

for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

######################################################################

for problem_number in range(1, 24):

    log_string('**** problem ' + str(problem_number) + ' ****')

    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    model_filename = model.name + '_' + \
                     str(problem_number) + '_' + \
                     str(args.nb_train_batches) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))
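    # For the AfrozeShallowNet defined above this should report 669,858
    # parameters, conv3 (120 maps of 16x18x18 kernels) alone accounting
    # for 622,200 of them.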

    # Reuse a previously trained model if its parameter file exists,
    # otherwise train one from scratch and save it
    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except FileNotFoundError:
        need_to_train = True

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        if args.compress_vignettes:
            train_set = CompressedVignetteSet(problem_number,
                                              args.nb_train_batches, args.batch_size,
                                              cuda=torch.cuda.is_available())
        else:
            train_set = VignetteSet(problem_number,
                                    args.nb_train_batches, args.batch_size,
                                    cuda=torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    if need_to_train or args.test_loaded_models:

        t = time.time()

        if args.compress_vignettes:
            test_set = CompressedVignetteSet(problem_number,
                                             args.nb_test_batches, args.batch_size,
                                             cuda=torch.cuda.is_available())
        else:
            test_set = VignetteSet(problem_number,
                                   args.nb_test_batches, args.batch_size,
                                   cuda=torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################
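
# Each call to log_string above produces one space-separated record in the
# log file; the record types written are: argument <name> <value>,
# '**** problem N ****', nb_parameters, loaded_model / training_model /
# saved_model, data_generation, train_loss <epoch> <loss>, and
# train_error / test_error <problem> <percent>% <nb_errors> <nb_samples>.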