3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
43 from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line interface of the script. ArgumentDefaultsHelpFormatter makes
# `--help` print the default value next to each option's help text.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# Data-set sizes are expressed in mini-batches, not in samples: the values
# are handed to the vignette sets together with --batch_size.
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

# Boolean flag: absent -> False, present -> True.
parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

# Parsed once at start-up; the rest of the script reads the options from `args`.
args = parser.parse_args()
78 ######################################################################
# Open the file named by --log_file for writing; mode 'w' truncates an
# existing file of that name.
# NOTE(review): no matching close() is visible in this listing.
80 log_file = open(args.log_file, 'w')
# Tell the user, in red on the terminal, where the log is being written.
83 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# NOTE(review): the lines below read `s`, `t` and `pred_log_t`, none of which
# is assigned in the visible lines. They look like the body of the
# log_string(s) helper that the rest of the script calls -- confirm against
# the complete source.
# A None pred_log_t presumably means there is no earlier log call to measure
# the elapsed time from -- TODO confirm.
89 if pred_log_t is None:
# Delay between `pred_log_t` and `t`, formatted like '+1.25s'.
92 elapsed = '+{:.02f}s'.format(t - pred_log_t)
# Prefix the message with the current wall-clock time (blue) and the elapsed
# time (green), then reset the terminal style before the message itself.
94 s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
# NOTE(review): `s` now contains the colorama escape sequences, so they are
# written to the log file too, not only to the terminal.
95 log_file.write(s + '\n')
99 ######################################################################
101 # Afroze's ShallowNet
104 # ----------------------
106 # -- conv(21x21 x 6) -> 108x108 6
107 # -- max(2x2) -> 54x54 6
108 # -- conv(19x19 x 16) -> 36x36 16
109 # -- max(2x2) -> 18x18 16
110 # -- conv(18x18 x 120) -> 1x1 120
111 # -- reshape -> 120 1
112 # -- full(120x84) -> 84 1
113 # -- full(84x2) -> 2 1
# Convolutional network following the layer table in the comment block just
# above: three convolutions, then two fully connected layers ending in 2
# outputs.
# NOTE(review): the constructor's `def` line and the end of forward() are not
# visible in this listing -- confirm against the complete source.
115 class AfrozeShallowNet(nn.Module):
# Initialise the nn.Module base class before registering the layers.
117 super(AfrozeShallowNet, self).__init__()
# 1 input channel -> 6 feature maps, 21x21 kernels.
118 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
# 6 -> 16 feature maps, 19x19 kernels.
119 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
# 16 -> 120 feature maps, 18x18 kernels.
120 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
# Fully connected layers: 120 -> 84 -> 2.
121 self.fc1 = nn.Linear(120, 84)
122 self.fc2 = nn.Linear(84, 2)
# Short identifier; the main loop uses it to build the parameter file name.
123 self.name = 'shallownet'
# Forward pass on a batch `x`.
# assumes x holds single-channel 128x128 images, as implied by the sizes in
# the layer table above -- TODO confirm.
125 def forward(self, x):
# First two stages: convolution, 2x2 max-pooling, then ReLU.
126 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
127 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
# Third convolution is followed by ReLU only (no pooling).
128 x = fn.relu(self.conv3(x))
# NOTE(review): fc1 takes 120 input features, and the layer table lists a
# reshape to 120 before it; that reshape, the fc2 call and the return
# statement are not visible here.
130 x = fn.relu(self.fc1(x))
134 ######################################################################
# Train `model` on the batches of `train_set` for args.nb_epochs epochs with
# a cross-entropy criterion and SGD (learning rate 1e-2), logging the
# accumulated loss as 'train_loss <epoch> <loss>'.
# NOTE(review): the initialisation of `acc_loss`, the gradient reset, the
# backward pass and the optimizer step are not visible in this listing --
# confirm against the complete source.
136 def train_model(model, train_set):
# NOTE(review): `batch_size` is not read by any of the visible lines below.
137 batch_size = args.batch_size
138 criterion = nn.CrossEntropyLoss()
# The guarded statement is not visible; presumably it moves the criterion to
# the GPU -- TODO confirm.
140 if torch.cuda.is_available():
143 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
# Epoch index `e` is 0-based here and logged 1-based below.
145 for e in range(0, args.nb_epochs):
# Visit the batches by index; the set exposes its size as `nb_batches`.
147 for b in range(0, train_set.nb_batches):
148 input, target = train_set.get_batch(b)
# Wrap the tensors in autograd Variables for the forward pass and the loss.
149 output = model.forward(Variable(input))
150 loss = criterion(output, Variable(target))
# loss.data[0] extracts the scalar value of the loss for this batch.
151 acc_loss = acc_loss + loss.data[0]
# Logged value is the sum of the batch losses, not their mean.
155 log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
159 ######################################################################
# Compare the predictions of `model` with the targets of every batch of
# `data_set`, sample by sample.
# NOTE(review): the error counter and the return statement are not visible in
# this listing; callers format the result with '{:d}', so it is presumably an
# integer count of misclassified samples -- TODO confirm.
161 def nb_errors(model, data_set):
163 for b in range(0, data_set.nb_batches):
164 input, target = data_set.get_batch(b)
165 output = model.forward(Variable(input))
# max(1)[1] is the index of the largest output along dimension 1, i.e. the
# predicted class of each sample; view(-1) flattens it to one entry per sample.
166 wta_prediction = output.data.max(1)[1].view(-1)
# Element-wise comparison in Python, one sample at a time.
168 for i in range(0, data_set.batch_size):
# A mismatch between prediction and target is one error.
169 if wta_prediction[i] != target[i]:
174 ######################################################################
# Record every parsed command-line option in the log, one
# 'argument <name> <value>' line per option.
for option_name, option_value in vars(args).items():
    log_string('argument ' + str(option_name) + ' ' + str(option_value))
179 ######################################################################
# Main experiment: for each SVRT problem 1..23, build a fresh network, load
# its saved parameters or train it, then log the train and test error rates.
# NOTE(review): several lines of this loop (branch keywords such as `else:`,
# some assignments and closing parentheses) are not visible in this listing;
# the comments below describe only what the visible lines show.
181 for problem_number in range(1, 24):
183 log_string('**** problem ' + str(problem_number) + ' ****')
# A new, untrained network for every problem.
185 model = AfrozeShallowNet()
# The guarded statement is not visible; presumably it moves the model to the
# GPU -- TODO confirm.
187 if torch.cuda.is_available():
# Parameter file name: '<model name>_<problem>_<nb_train_batches>.param'.
190 model_filename = model.name + '_' + \
191 str(problem_number) + '_' + \
192 str(args.nb_train_batches) + '.param'
# Sum the element counts of all parameter tensors (the initialisation of
# `nb_parameters` is not visible here).
195 for p in model.parameters(): nb_parameters += p.numel()
196 log_string('nb_parameters {:d}'.format(nb_parameters))
# Default to reusing saved parameters; the code that sets this flag to True is
# not visible (presumably when loading fails) -- TODO confirm.
198 need_to_train = False
200 model.load_state_dict(torch.load(model_filename))
201 log_string('loaded_model ' + model_filename)
# Training path.
207 log_string('training_model ' + model_filename)
# Generate the training set, compressed or not depending on
# --compress_vignettes; the data is placed on the GPU when one is available.
211 if args.compress_vignettes:
212 train_set = CompressedVignetteSet(problem_number,
213 args.nb_train_batches, args.batch_size,
214 cuda=torch.cuda.is_available())
# Uncompressed alternative (the `else:` line is not visible).
216 train_set = VignetteSet(problem_number,
217 args.nb_train_batches, args.batch_size,
218 cuda=torch.cuda.is_available())
# Generation throughput in samples per second since `t` (the assignment of
# `t` is not visible).
# NOTE(review): `test_set` is read here, but its first visible assignment is
# further down in this loop. On the first problem this looks like a NameError,
# and on later problems it would count the previous problem's test set --
# confirm against the complete source.
220 log_string('data_generation {:0.2f} samples / s'.format(
221 (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
# Train, then save the learned parameters under the file name built above.
224 train_model(model, train_set)
225 torch.save(model.state_dict(), model_filename)
226 log_string('saved_model ' + model_filename)
# Error on the training data.
228 nb_train_errors = nb_errors(model, train_set)
# The format has four fields; only the percentage and the sample count are
# visible among the arguments here.
230 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
232 100 * nb_train_errors / train_set.nb_samples,
234 train_set.nb_samples)
# Generate the test set with the same compressed / uncompressed choice.
237 if args.compress_vignettes:
238 test_set = CompressedVignetteSet(problem_number,
239 args.nb_test_batches, args.batch_size,
240 cuda=torch.cuda.is_available())
# Uncompressed alternative (the `else:` line is not visible).
242 test_set = VignetteSet(problem_number,
243 args.nb_test_batches, args.batch_size,
244 cuda=torch.cuda.is_available())
# Error on the freshly generated test data.
246 nb_test_errors = nb_errors(model, test_set)
248 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
250 100 * nb_test_errors / test_set.nb_samples,
256 ######################################################################