1 #!/usr/bin/env python-for-pytorch
3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from torch import optim
29 from torch import FloatTensor as Tensor
30 from torch.autograd import Variable
32 from torch.nn import functional as fn
33 from torchvision import datasets, transforms, utils
37 ######################################################################
def generate_set(p, n):
    """Build a labelled set of n vignettes for problem number p.

    Draws n binary labels (each 1 with probability 0.5), asks
    svrt.generate_vignettes for the corresponding images, and reshapes
    the float-converted result so that a singleton dimension sits at
    position 1.

    Returns the pair (images, labels), each wrapped in a Variable.
    """
    labels = torch.LongTensor(n).bernoulli_(0.5)
    vignettes = svrt.generate_vignettes(p, labels)

    # Insert a size-1 dimension right after the sample dimension.
    nb_samples = vignettes.size(0)
    height = vignettes.size(1)
    width = vignettes.size(2)
    vignettes = vignettes.view(nb_samples, 1, height, width).float()

    return Variable(vignettes), Variable(labels)
45 ######################################################################
47 # 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
# NOTE(review): the class header and the method signatures that should
# enclose the statements below are not present in this listing; the
# statements reference `Net` and `self`, so they presumably belong to
# the constructor and the forward pass of a `Net` module -- confirm
# against the complete file.
#
# Layer set-up: two convolutions (1 -> 10 channels with a 9x9 kernel,
# then 10 -> 20 channels with a 6x6 kernel) and two linear layers
# (500 -> 100, then 100 -> 2).
# The 500 inputs of fc1 presumably correspond to 20 channels x 5 x 5
# from the size chain above -- TODO confirm, the flattening step is not
# visible here.
51 super(Net, self).__init__()
52 self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
53 self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
54 self.fc1 = nn.Linear(500, 100)
55 self.fc2 = nn.Linear(100, 2)
# Forward computation: each convolution is followed by a max-pooling
# whose stride equals its kernel size (4, then 5), and the ReLU is
# applied to the pooled result.
58 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
59 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
# NOTE(review): no visible line reshapes x before fc1, applies fc2 or
# returns a value; those statements appear to be missing from this
# listing.
61 x = fn.relu(self.fc1(x))
# Creates a Net together with a cross-entropy criterion and iterates
# over train_input / train_target in mini-batches, computing the loss
# for each of them.
# NOTE(review): several statements of this function are not visible in
# this listing (the body of the `if` below, the definition of
# nb_epochs, whatever follows the loss computation, and any return
# statement), so only the visible steps are described.
65 def train_model(train_input, train_target):
66 model, criterion = Net(), nn.CrossEntropyLoss()
# CUDA-specific handling; the guarded statements are not visible here.
68 if torch.cuda.is_available():
# Plain SGD over all model parameters with learning rate 0.1; bs is
# the mini-batch size.
73 optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
# nb_train_samples is read from module level, not passed as argument.
75 for k in range(0, nb_epochs):
76 for b in range(0, nb_train_samples, bs):
# narrow(0, b, bs) selects the bs samples starting at index b along the
# first dimension -- presumably requires nb_train_samples to be a
# multiple of bs; verify.
77 output = model.forward(train_input.narrow(0, b, bs))
78 loss = criterion(output, train_target.narrow(0, b, bs))
85 ######################################################################
# Runs the model over test_input in mini-batches, counts the samples
# whose predicted class differs from test_target, and prints the error
# rate.
# NOTE(review): the initialisation of bs and nb_test_errors and the end
# of the print(...) call are not visible in this listing.
87 def print_test_error(model, test_input, test_target):
# nb_test_samples is read from module level, not passed as argument.
91 for b in range(0, nb_test_samples, bs):
92 output = model.forward(test_input.narrow(0, b, bs))
# torch.max along dimension 1 returns (values, indices); only the
# indices are kept, as the predicted class of each sample.
93 _, wta = torch.max(output.data, 1)
# Per-sample comparison of the prediction with the matching target.
# wta is indexed with two subscripts, i.e. it is treated as 2-D here.
95 for i in range(0, bs):
96 if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
97 nb_test_errors = nb_test_errors + 1
# First format argument: error rate as a percentage. The remaining
# arguments are not visible here.
99 print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
100 100 * nb_test_errors / nb_test_samples,
105 ######################################################################
# Driver: for each problem number, generate a training and a test set,
# normalise them, train a model and report its test error together
# with the time spent in each stage.
# NOTE(review): the assignments of t1, t2, t3 and t4 used in the timing
# messages are not visible in this listing.
107 nb_train_samples = 100000
108 nb_test_samples = 10000
# Problem numbers 1 to 23 inclusive.
110 for p in range(1, 24):
111 print('-- PROBLEM #{:d} --'.format(p))
114 train_input, train_target = generate_set(p, nb_train_samples)
115 test_input, test_target = generate_set(p, nb_test_samples)
# Move both sets to the GPU when one is available.
116 if torch.cuda.is_available():
117 train_input, train_target = train_input.cuda(), train_target.cuda()
118 test_input, test_target = test_input.cuda(), test_target.cuda()
# Normalise in place. The mean and standard deviation are computed on
# the training inputs only and applied to both the training and the
# test inputs.
120 mu, std = train_input.data.mean(), train_input.data.std()
121 train_input.data.sub_(mu).div_(std)
122 test_input.data.sub_(mu).div_(std)
125 print('[data generation {:.02f}s]'.format(t2 - t1))
126 model = train_model(train_input, train_target)
129 print('[train {:.02f}s]'.format(t3 - t2))
130 print_test_error(model, test_input, test_target)
134 print('[test {:.02f}s]'.format(t4 - t3))
137 ######################################################################