#  svrt is the ``Synthetic Visual Reasoning Test'', an image
#  generator for evaluating classification performance of machine
#  learning systems, humans and primates.
#
#  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
#  Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
#  This file is part of svrt.
#
#  svrt is free software: you can redistribute it and/or modify it
#  under the terms of the GNU General Public License version 3 as
#  published by the Free Software Foundation.
#
#  svrt is distributed in the hope that it will be useful, but
#  WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
#  General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import time

from colorama import Fore, Back, Style

import torch
from torch import nn
from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

import svrt
######################################################################

# Command-line configuration. ArgumentDefaultsHelpFormatter makes
# `--help` print the default value of every option.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# Number of vignettes generated per problem for training.
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000,
                    help = 'How many samples for train')

# Number of vignettes generated per problem for testing.
parser.add_argument('--nb_test_samples',
                    type = int, default = 10000,
                    help = 'How many samples for test')

# Number of passes over the training set.
parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

# File that receives every message passed to log_string.
parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

args = parser.parse_args()
######################################################################

# The log file stays open for the whole run. It is truncated at start-up
# ('w' mode).
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)


def log_string(s):
    """Prefix *s* with the current time, append it to the log file and echo it.

    The timestamp is wrapped in colorama color codes, and those codes are
    written to the log file as well as to the terminal.

    Args:
        s: the message (a str) to record.
    """
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    # Flush every line so the log is complete even if the long run is
    # interrupted.
    log_file.flush()
    print(s)

######################################################################
def generate_set(p, n):
    """Generate *n* labelled vignettes for SVRT problem *p*.

    Labels are drawn as a LongTensor of n Bernoulli(0.5) values and handed
    to ``svrt.generate_vignettes``. The generation rate is logged.

    Args:
        p: SVRT problem number.
        n: number of samples to generate.

    Returns:
        ``(input, target)`` wrapped in ``Variable``. ``input`` is the
        generator output reshaped to (N, 1, H, W) floats, i.e. a single
        channel is inserted. This assumes the generator returns an
        (N, H, W) tensor; TODO confirm against svrt.
    """
    target = torch.LongTensor(n).bernoulli_(0.5)

    start = time.time()
    vignettes = svrt.generate_vignettes(p, target)
    elapsed = time.time() - start

    # A coarse clock can report 0 s for a tiny set; avoid dividing by zero.
    log_string('data_set_generation {:.02f} sample/s'.format(n / max(elapsed, 1e-9)))

    vignettes = vignettes.view(vignettes.size(0), 1,
                               vignettes.size(1), vignettes.size(2)).float()
    return Variable(vignettes), Variable(target)
######################################################################

# AfrozeShallowNet architecture
#
#                       map size   nb. maps
#                     ----------------------
#    input                128x128    1
# -- conv(21x21 x 6)   -> 108x108    6
# -- max(2x2)          -> 54x54      6
# -- conv(19x19 x 16)  -> 36x36      16
# -- max(2x2)          -> 18x18      16
# -- conv(18x18 x 120) -> 1x1        120
# -- reshape           -> 120        1
# -- full(120x84)      -> 84         1
# -- full(84x2)        -> 2          1
class AfrozeShallowNet(nn.Module):
    """Shallow convnet mapping a 1-channel image to 2 class scores.

    Three convolutions (the first two followed by 2x2 max-pooling) and two
    fully-connected layers, with ReLU after every layer but the last. The
    kernel sizes are chosen so that a 128x128 input is reduced by conv3 to
    120 maps of size 1x1. forward() returns unnormalized scores of shape
    (batch, 2).
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the (batch, 2) scores for a (batch, 1, 128, 128) input."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten the 120 x 1 x 1 maps per sample. Keeping the batch
        # dimension explicit makes a wrong input size fail in fc1 instead
        # of being silently regrouped across samples.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        # No final non-linearity: CrossEntropyLoss expects raw scores.
        return self.fc2(x)
def train_model(model, train_input, train_target, nb_epochs=None, bs=100, lr=1e-2):
    """Train *model* in place with mini-batch SGD on a cross-entropy loss.

    The summed training loss of every epoch is logged.

    Args:
        model: the network to train. It is moved to the GPU when one is
            available, so the data is expected to live there too.
        train_input: training inputs, batch dimension first.
        train_target: class indices matching ``train_input``.
        nb_epochs: number of passes over the data (default:
            ``args.nb_epochs``).
        bs: mini-batch size. The last batch may be smaller.
        lr: SGD learning rate.

    Returns:
        The trained model (the same object that was passed in).
    """
    if nb_epochs is None:
        nb_epochs = args.nb_epochs

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = lr)
    nb_samples = train_input.size(0)

    for k in range(nb_epochs):
        acc_loss = 0.0
        for b in range(0, nb_samples, bs):
            # Clip the last mini-batch when nb_samples is not a multiple of
            # bs; narrow() would otherwise run past the end.
            m = min(bs, nb_samples - b)
            output = model(train_input.narrow(0, b, m))
            loss = criterion(output, train_target.narrow(0, b, m))
            # loss.data[0] fails on 0-dim tensors in current PyTorch;
            # float(... .sum()) works for 0-dim and 1-element losses alike.
            acc_loss += float(loss.data.sum())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(k, acc_loss))

    return model
######################################################################

def nb_errors(model, data_input, data_target, bs = 100):
    """Count samples whose arg-max prediction differs from the target.

    Args:
        model: callable returning (batch, nb_classes) scores.
        data_input: inputs, batch dimension first.
        data_target: class indices matching ``data_input``.
        bs: evaluation batch size. The last batch may be smaller.

    Returns:
        The number of misclassified samples as a Python int.
    """
    nb_samples = data_input.size(0)
    ne = 0

    for b in range(0, nb_samples, bs):
        # Clip the last batch so sizes that are not multiples of bs work.
        m = min(bs, nb_samples - b)
        output = model(data_input.narrow(0, b, m))
        # Winner-take-all: the predicted class is the arg-max score.
        wta_prediction = output.data.max(1)[1].view(-1)
        target = data_target.data.narrow(0, b, m)
        ne += int((wta_prediction != target).sum())

    return ne
######################################################################

for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

# Run every problem numbered 1 to 23 in turn.
for problem_number in range(1, 24):
    train_input, train_target = generate_set(problem_number, args.nb_train_samples)
    test_input, test_target = generate_set(problem_number, args.nb_test_samples)
    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        train_input, train_target = train_input.cuda(), train_target.cuda()
        test_input, test_target = test_input.cuda(), test_target.cuda()
        # The model must be on the same device as the data even when it is
        # loaded from disk below and train_model() is never called.
        model.cuda()

    # Standardize both sets with statistics of the training set only.
    mu, std = train_input.data.mean(), train_input.data.std()
    train_input.data.sub_(mu).div_(std)
    test_input.data.sub_(mu).div_(std)

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    model_filename = 'model_' + str(problem_number) + '.param'

    # Reuse a saved model for this problem if there is one; otherwise train
    # from scratch and save it. Only a missing or unreadable file triggers
    # training. A mismatching state dict is an error and is not hidden.
    try:
        # map_location keeps weights saved on a GPU loadable on a CPU host.
        state = torch.load(model_filename,
                           map_location = lambda storage, loc: storage)
    except (IOError, OSError):
        log_string('training_model')
        train_model(model, train_input, train_target)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)
    else:
        model.load_state_dict(state)
        log_string('loaded_model ' + model_filename)

    for set_name, set_input, set_target in (('train', train_input, train_target),
                                            ('test', test_input, test_target)):
        nb_set_errors = nb_errors(model, set_input, set_target)
        # 100.0 avoids integer division under Python 2.
        log_string('{:s}_error {:d} {:.02f}% {:d} {:d}'.format(
            set_name,
            problem_number,
            100.0 * nb_set_errors / set_input.size(0),
            nb_set_errors,
            set_input.size(0))
        )

######################################################################