3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
43 from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line interface of the experiment. Every option has a default,
# and ArgumentDefaultsHelpFormatter appends that default to the --help
# text.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# The two data-set sizes are counted in mini-batches, not in samples:
# they are handed to the vignette sets together with --batch_size, so
# the number of samples is nb_batches * batch_size.
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

# Boolean switches: absent means False, present means True.
parser.add_argument('--compress_vignettes',
                    action = 'store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--test_loaded_models',
                    action = 'store_true', default = False,
                    help = 'Should we compute the test error of models we load')
# Parse the process command line once, at module level. The resulting
# namespace is read as a global by the code below (e.g. args.nb_epochs,
# args.batch_size, args.log_file).
80 args = parser.parse_args()
82 ######################################################################
# Open the log file in 'w' mode, which truncates any existing file of
# that name. The handle is kept in a module-level global; no matching
# close() is visible in this chunk.
84 log_file = open(args.log_file, 'w')
# Tell the user (in red, then reset the terminal style) where the log goes.
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Fragment of the logging helper that the rest of the file calls as
# log_string(...). NOTE(review): its `def` line and several body lines
# (including the bodies of the branches below) are not visible in this
# chunk -- confirm against the full file.
# No previous log time recorded yet.
94 if pred_log_t is None:
# Seconds elapsed since the previously recorded log time, e.g. '+1.25s'
# (presumably the else-branch of the test above -- TODO confirm).
97 elapsed = '+{:.02f}s'.format(t - pred_log_t)
# Decorate the message: blue wall-clock time, green elapsed marker, style
# reset, then the caller's text.
101 s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
# The decorated string is what gets written, so the colour escape
# sequences end up in the log file as well.
102 log_file.write(s + '\n')
106 ######################################################################
108 # Afroze's ShallowNet
111 # ----------------------
113 # -- conv(21x21 x 6) -> 108x108 6
114 # -- max(2x2) -> 54x54 6
115 # -- conv(19x19 x 16) -> 36x36 16
116 # -- max(2x2) -> 18x18 16
117 # -- conv(18x18 x 120) -> 1x1 120
118 # -- reshape -> 120 1
119 # -- full(120x84) -> 84 1
120 # -- full(84x2) -> 2 1
# Convolutional classifier with three convolution layers followed by two
# fully connected layers, the last of which has 2 outputs.
# NOTE(review): the `def __init__` line and the end of forward() (anything
# after the fc1 line) are not visible in this chunk.
122 class AfrozeShallowNet(nn.Module):
124 super(AfrozeShallowNet, self).__init__()
# 1 input channel -> 6 feature maps, 21x21 kernels.
125 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
# 6 -> 16 feature maps, 19x19 kernels.
126 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
# 16 -> 120 feature maps, 18x18 kernels.
127 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
# Fully connected head: 120 -> 84 -> 2.
128 self.fc1 = nn.Linear(120, 84)
129 self.fc2 = nn.Linear(84, 2)
# Short identifier; the main loop uses it as the prefix of the parameter
# file name.
130 self.name = 'shallownet'
# Forward pass of the network on the input x.
132 def forward(self, x):
# Stages 1 and 2: convolution, then 2x2 max-pooling, then ReLU.
133 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
134 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
# Stage 3: convolution followed by ReLU, no pooling.
135 x = fn.relu(self.conv3(x))
# First fully connected layer followed by ReLU. NOTE(review): fc1 expects
# 120 features per sample, so a reshape presumably happens on a line that
# is not visible here -- TODO confirm.
137 x = fn.relu(self.fc1(x))
141 ######################################################################
# Train `model` on the batches of `train_set` for args.nb_epochs epochs
# using SGD (lr = 1e-2) and a cross-entropy criterion, logging the
# accumulated loss and printing an ETA.
# NOTE(review): several lines of this function are not visible in this
# chunk: the body of the CUDA branch, the initialisation of acc_loss,
# whatever follows the loss accumulation inside the batch loop, and the
# end of the function. Confirm against the full file.
143 def train_model(model, train_set):
# NOTE(review): this local is not used by any line visible here.
144 batch_size = args.batch_size
145 criterion = nn.CrossEntropyLoss()
# Branch body not visible.
147 if torch.cuda.is_available():
150 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
# Reference time for the ETA estimate below.
152 start_t = time.time()
154 for e in range(0, args.nb_epochs):
156 for b in range(0, train_set.nb_batches):
# NOTE(review): `input` shadows the Python builtin of the same name.
157 input, target = train_set.get_batch(b)
158 output = model.forward(Variable(input))
159 loss = criterion(output, Variable(target))
# Sum (not mean) of the batch losses.
# NOTE(review): `loss.data[0]` relies on legacy PyTorch behaviour; current
# releases reject indexing a 0-dim tensor and document `loss.item()` as
# the replacement -- confirm the targeted PyTorch version.
160 acc_loss = acc_loss + loss.data[0]
# Log the 1-based epoch number and the accumulated loss.
164 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
# Average wall-clock time per epoch over the e + 1 epochs counted so far.
165 dt = (time.time() - start_t) / (e + 1)
# NOTE(review): if this runs at the end of epoch e (as the e + 1 above
# suggests), the number of remaining epochs is args.nb_epochs - e - 1, so
# the printed ETA looks one epoch too late -- TODO confirm.
166 print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL)
170 ######################################################################
# Compare the predictions of `model` with the targets of every batch of
# `data_set`; judging by the name, it presumably returns the number of
# mismatches.
# NOTE(review): the counter initialisation, its increment and the return
# statement are not visible in this chunk -- TODO confirm.
172 def nb_errors(model, data_set):
174 for b in range(0, data_set.nb_batches):
# NOTE(review): `input` shadows the Python builtin of the same name.
175 input, target = data_set.get_batch(b)
176 output = model.forward(Variable(input))
# Winner-take-all prediction: index of the largest output along
# dimension 1, flattened to one dimension.
177 wta_prediction = output.data.max(1)[1].view(-1)
# Element-by-element comparison over data_set.batch_size entries.
179 for i in range(0, data_set.batch_size):
180 if wta_prediction[i] != target[i]:
185 ######################################################################
# Record every parsed command-line option and its value in the log, one
# entry per option.
for option_name, option_value in vars(args).items():
    log_string('argument ' + str(option_name) + ' ' + str(option_value))
190 ######################################################################
# Main experiment loop over the problem numbers 1 to 23: build a fresh
# model, load or train its parameters, and log error figures.
# NOTE(review): many lines of this loop are not visible in this chunk
# (branch bodies, else/try lines, timer assignments, closing lines of some
# calls), so the exact control flow between the statements below has to be
# confirmed against the full file.
192 for problem_number in range(1, 24):
194 log_string('**** problem ' + str(problem_number) + ' ****')
196 model = AfrozeShallowNet()
# Branch body not visible.
198 if torch.cuda.is_available():
# Parameter file name: <model name>_<problem number>_<nb_train_batches>.param
# NOTE(review): the batch size and the number of epochs are not part of
# the name, so runs that differ only in those options share a file.
201 model_filename = model.name + '_' + \
202 str(problem_number) + '_' + \
203 str(args.nb_train_batches) + '.param'
# Total number of scalar parameters (nb_parameters is initialised on a
# line not visible here).
206 for p in model.parameters(): nb_parameters += p.numel()
207 log_string('nb_parameters {:d}'.format(nb_parameters))
# Default: assume saved parameters can be reused.
209 need_to_train = False
# Load previously saved parameters from model_filename. How a missing or
# unreadable file is handled is not visible -- presumably it switches
# need_to_train on; TODO confirm.
211 model.load_state_dict(torch.load(model_filename))
212 log_string('loaded_model ' + model_filename)
# Training path.
218 log_string('training_model ' + model_filename)
# Build the training set, using the compressed variant when
# --compress_vignettes was given.
222 if args.compress_vignettes:
223 train_set = CompressedVignetteSet(problem_number,
224 args.nb_train_batches, args.batch_size,
225 cuda=torch.cuda.is_available())
227 train_set = VignetteSet(problem_number,
228 args.nb_train_batches, args.batch_size,
229 cuda=torch.cuda.is_available())
# Generation throughput; `t` is set on a line not visible here.
231 log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
# Train, then save the parameters under model_filename.
233 train_model(model, train_set)
234 torch.save(model.state_dict(), model_filename)
235 log_string('saved_model ' + model_filename)
# Training error; the visible arguments are the percentage and the sample
# count.
237 nb_train_errors = nb_errors(model, train_set)
239 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
241 100 * nb_train_errors / train_set.nb_samples,
243 train_set.nb_samples)
# Evaluate on test data when need_to_train is set, or when
# --test_loaded_models was given.
246 if need_to_train or args.test_loaded_models:
# Build the test set the same way as the training set, sized by
# --nb_test_batches.
250 if args.compress_vignettes:
251 test_set = CompressedVignetteSet(problem_number,
252 args.nb_test_batches, args.batch_size,
253 cuda=torch.cuda.is_available())
255 test_set = VignetteSet(problem_number,
256 args.nb_test_batches, args.batch_size,
257 cuda=torch.cuda.is_available())
259 log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
# Test error, logged in the same format as the training error.
261 nb_test_errors = nb_errors(model, test_set)
263 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
265 100 * nb_test_errors / test_set.nb_samples,
270 ######################################################################