3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import os
import time

from colorama import Fore, Back, Style

import torch
from torch import optim
from torch import FloatTensor as Tensor
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils
40 ######################################################################
# Command-line configuration. The ArgumentDefaultsHelpFormatter makes
# --help show each option's default value.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many samples for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many samples for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

args = parser.parse_args()
69 ######################################################################
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

def log_string(s):
    """Write s, prefixed by a timestamp, to the log file and to stdout.

    The timestamp is colored only on the console-visible copy; the same
    string (with ANSI codes) is written to the log file as well.
    """
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    # flush so the log is usable while long trainings are still running
    log_file.flush()
    print(s)
81 ######################################################################
def generate_set(p, n):
    """Generate n vignettes for SVRT problem number p.

    Returns (input, target) as autograd Variables: input is an
    n x 1 x H x W float tensor of images, target is n binary labels
    drawn with probability 0.5 each.
    """
    target = torch.LongTensor(n).bernoulli_(0.5)
    # time the C++ vignette generator so throughput can be logged
    t = time.time()
    input = svrt.generate_vignettes(p, target)
    t = time.time() - t
    log_string('data_set_generation {:.02f} sample/s'.format(n / t))
    # add the single channel dimension expected by Conv2d
    input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
    return Variable(input), Variable(target)
92 ######################################################################
97 # ----------------------
99 # -- conv(21x21 x 6) -> 108x108 6
100 # -- max(2x2) -> 54x54 6
101 # -- conv(19x19 x 16) -> 36x36 16
102 # -- max(2x2) -> 18x18 16
103 # -- conv(18x18 x 120) -> 1x1 120
104 # -- reshape -> 120 1
105 # -- full(120x84) -> 84 1
106 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Shallow LeNet-style convnet for 1x128x128 SVRT vignettes.

    conv(21x21 x 6) -> 108x108x6 -> max(2x2) -> 54x54x6
    -> conv(19x19 x 16) -> 36x36x16 -> max(2x2) -> 18x18x16
    -> conv(18x18 x 120) -> 1x1x120 -> full(120->84) -> full(84->2)
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # flatten the 1x1 spatial map to a 120-vector per sample
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        # raw class scores (2 classes); CrossEntropyLoss applies softmax
        x = self.fc2(x)
        return x
def train_model(model, train_input, train_target):
    """Train model in place with SGD + cross-entropy for args.nb_epochs
    epochs over mini-batches of args.batch_size, logging the accumulated
    loss per epoch. Returns the trained model."""
    bs = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for k in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_input.size(0), bs):
            output = model.forward(train_input.narrow(0, b, bs))
            loss = criterion(output, train_target.narrow(0, b, bs))
            # loss.data[0] is the old (pre-0.4) PyTorch scalar access,
            # consistent with the Variable usage elsewhere in this file
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(k, acc_loss))

    return model
148 ######################################################################
def nb_errors(model, data_input, data_target):
    """Return the number of samples in (data_input, data_target) that
    model misclassifies, evaluated in mini-batches of args.batch_size.

    NOTE(review): assumes data_input.size(0) is a multiple of the batch
    size; the final narrow() would fail otherwise — confirm at call site.
    """
    bs = args.batch_size
    ne = 0

    for b in range(0, data_input.size(0), bs):
        output = model.forward(data_input.narrow(0, b, bs))
        # winner-take-all: index of the max class score per sample
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, bs):
            if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
                ne = ne + 1

    return ne
164 ######################################################################
# Record the full configuration in the log, one line per argument.
for arg in vars(args):
    log_string('argument {} {}'.format(arg, getattr(args, arg)))
# Main experiment: train and evaluate one model per SVRT problem (1..23).
for problem_number in range(1, 24):
    train_input, train_target = generate_set(problem_number,
                                             args.nb_train_batches * args.batch_size)
    test_input, test_target = generate_set(problem_number,
                                           args.nb_test_batches * args.batch_size)
    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        train_input, train_target = train_input.cuda(), train_target.cuda()
        test_input, test_target = test_input.cuda(), test_target.cuda()
        model.cuda()

    # Standardize with the TRAIN statistics only, applied to both sets
    mu, std = train_input.data.mean(), train_input.data.std()
    train_input.data.sub_(mu).div_(std)
    test_input.data.sub_(mu).div_(std)

    nb_parameters = 0
    for p in model.parameters():
        nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    model_filename = 'model_' + str(problem_number) + '.param'

    # Reuse a previously trained model when available; otherwise train
    # from scratch and save the parameters for the next run.
    if os.path.isfile(model_filename):
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    else:
        log_string('training_model')
        train_model(model, train_input, train_target)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

    nb_train_errors = nb_errors(model, train_input, train_target)

    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_train_errors / train_input.size(0),
        nb_train_errors,
        train_input.size(0)))

    nb_test_errors = nb_errors(model, test_input, test_target)

    log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_test_errors / test_input.size(0),
        nb_test_errors,
        test_input.size(0)))
219 ######################################################################