3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from torch import optim
30 from torch import FloatTensor as Tensor
31 from torch.autograd import Variable
33 from torch.nn import functional as fn
34 from torchvision import datasets, transforms, utils
38 ######################################################################
40 parser = argparse.ArgumentParser(
41 description = 'Simple convnet test on the SVRT.',
42 formatter_class = argparse.ArgumentDefaultsHelpFormatter
# Data-set sizes and training length, all integers. The
# ArgumentDefaultsHelpFormatter configured on the parser appends each
# default value to its --help entry.
parser.add_argument('--nb_train_samples', default = 100000, type = int,
                    help = 'How many samples for train')
parser.add_argument('--nb_test_samples', default = 10000, type = int,
                    help = 'How many samples for test')
parser.add_argument('--nb_epochs', default = 25, type = int,
                    help = 'How many training epochs')

# Parse sys.argv once, at import time; the rest of the script reads `args`.
args = parser.parse_args()
59 ######################################################################
# Run log: the logging code that follows writes each message here as one
# line. Mode 'w' truncates any log left over from a previous run.
log_file = open('cnn-svrt.log', mode = 'w')
64 s = time.ctime() + ' ' + str(problem_number) + ' | ' + s
65 log_file.write(s + '\n')
69 ######################################################################
def generate_set(p, n):
    """Build a randomly labelled SVRT data set for one problem.

    p -- SVRT problem number, forwarded to svrt.generate_vignettes
    n -- number of labels to draw

    Each label is drawn independently as 0 or 1 with probability 1/2.
    The labels are then passed to svrt.generate_vignettes, which presumably
    renders one image per label (TODO confirm against the svrt module).

    Returns (input, target), both wrapped in Variable. input is a float
    tensor with a singleton channel axis inserted at dim 1, so a 3-D result
    (N, H, W) becomes (N, 1, H, W). target is the LongTensor of labels.
    """
    labels = torch.LongTensor(n).bernoulli_(0.5)
    vignettes = svrt.generate_vignettes(p, labels)

    # Insert the channel axis expected by the single-input-channel conv net.
    nb = vignettes.size(0)
    height = vignettes.size(1)
    width = vignettes.size(2)
    images = vignettes.view(nb, 1, height, width).float()

    return Variable(images), Variable(labels)
77 ######################################################################
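# Feature-map shapes (channels x height x width) for a 128x128 input:
# 1x128x128 --conv(9)-> 10x120x120 --max(4)-> 10x30x30
#   --conv(6)-> 20x25x25 --max(5)-> 20x5x5 = 500 features into fc1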
83 super(Net, self).__init__()
84 self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
85 self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
86 self.fc1 = nn.Linear(500, 100)
87 self.fc2 = nn.Linear(100, 2)
90 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
91 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
93 x = fn.relu(self.fc1(x))
97 def train_model(train_input, train_target):
98 model, criterion = Net(), nn.CrossEntropyLoss()
100 if torch.cuda.is_available():
104 optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
106 for k in range(0, args.nb_epochs):
108 for b in range(0, train_input.size(0), bs):
109 output = model.forward(train_input.narrow(0, b, bs))
110 loss = criterion(output, train_target.narrow(0, b, bs))
111 acc_loss = acc_loss + loss.data[0]
115 log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
119 ######################################################################
121 def nb_errors(model, data_input, data_target, bs = 100):
124 for b in range(0, data_input.size(0), bs):
125 output = model.forward(data_input.narrow(0, b, bs))
126 wta_prediction = output.data.max(1)[1].view(-1)
128 for i in range(0, bs):
129 if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
134 ######################################################################
136 for problem_number in range(1, 24):
137 train_input, train_target = generate_set(problem_number, args.nb_train_samples)
138 test_input, test_target = generate_set(problem_number, args.nb_test_samples)
140 if torch.cuda.is_available():
141 train_input, train_target = train_input.cuda(), train_target.cuda()
142 test_input, test_target = test_input.cuda(), test_target.cuda()
144 mu, std = train_input.data.mean(), train_input.data.std()
145 train_input.data.sub_(mu).div_(std)
146 test_input.data.sub_(mu).div_(std)
148 model = train_model(train_input, train_target)
150 nb_test_errors = nb_errors(model, test_input, test_target)
152 log_string('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
153 100 * nb_test_errors / test_input.size(0),
158 ######################################################################