3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
26 from colorama import Fore, Back, Style
30 from torch import optim
31 from torch import FloatTensor as Tensor
32 from torch.autograd import Variable
34 from torch.nn import functional as fn
35 from torchvision import datasets, transforms, utils
39 ######################################################################
# Command-line options of the script: the number of training and test
# samples, the number of training epochs, and the log file name.
# ArgumentDefaultsHelpFormatter makes --help print each option's default.
# NOTE(review): the closing parenthesis of the ArgumentParser(...) call is
# not present in the visible lines -- confirm against the complete file.
41 parser = argparse.ArgumentParser(
42 description = 'Simple convnet test on the SVRT.',
43 formatter_class = argparse.ArgumentDefaultsHelpFormatter
# Number of samples requested from generate_set for the training set.
46 parser.add_argument('--nb_train_samples',
47 type = int, default = 100000,
48 help = 'How many samples for train')
# Number of samples requested from generate_set for the test set.
50 parser.add_argument('--nb_test_samples',
51 type = int, default = 10000,
52 help = 'How many samples for test')
# Number of passes over the training set made by train_model.
54 parser.add_argument('--nb_epochs',
55 type = int, default = 25,
56 help = 'How many training epochs')
# Path of the text log opened for writing right after parsing.
58 parser.add_argument('--log_file',
59 type = str, default = 'cnn-svrt.log',
60 help = 'Log file name')
# Parsed at import time: the options are read from sys.argv.
62 args = parser.parse_args()
64 ######################################################################
# Open the log file in 'w' mode (any previous content is truncated) and
# announce its name on the console in red.
66 log_file = open(args.log_file, 'w')
68 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Prefix the message `s` with a green timestamp and the global
# problem_number, then append it to the log file as one line.
# NOTE(review): `s` is not assigned anywhere in the visible lines. These
# statements look like the body of the log_string(s) helper called
# elsewhere in the file, whose `def` line is not visible -- confirm.
71 s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
72 str(problem_number) + ' ' + s
73 log_file.write(s + '\n')
77 ######################################################################
# Build a labelled data set of `n` samples for SVRT problem `p`.
# Returns the pair (input, target), each wrapped in a Variable.
79 def generate_set(p, n):
# Labels: a LongTensor of n entries filled in place by bernoulli_(0.5).
80 target = torch.LongTensor(n).bernoulli_(0.5)
# The images are produced from the labels by svrt.generate_vignettes.
82 input = svrt.generate_vignettes(p, target)
# Log the generation throughput.
# NOTE(review): `t` is not assigned in the visible lines; presumably it
# holds the elapsed generation time in seconds -- confirm.
84 log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
# Insert a size-1 dimension at index 1 (presumably the single input
# channel expected by Net.conv1) and convert to float.
85 input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86 return Variable(input), Variable(target)
88 ######################################################################
93 # ----------------------
95 # -- conv(21x21) -> 108x108 6
96 # -- max(2x2) -> 54x54 6
97 # -- conv(19x19) -> 36x36 16
98 # -- max(2x2) -> 18x18 16
99 # -- conv(18x18) -> 1x1 120
100 # -- reshape -> 120 1
101 # -- full(120x84) -> 84 1
102 # -- full(84x2) -> 2 1
# Convolutional classifier: three convolution layers followed by two
# linear layers, the last one producing 2 outputs.
# NOTE(review): the `def __init__` line and the end of forward() are not
# present in the visible lines. Presumably forward() flattens the
# activations to 120 features before fc1, then applies fc2 and returns the
# result -- confirm against the complete file.
104 class Net(nn.Module):
106 super(Net, self).__init__()
# 1 input channel -> 6 feature maps, 21x21 kernels.
107 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
# 6 -> 16 feature maps, 19x19 kernels.
108 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
# 16 -> 120 feature maps, 18x18 kernels.
109 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
# Fully connected layers: 120 -> 84 -> 2.
110 self.fc1 = nn.Linear(120, 84)
111 self.fc2 = nn.Linear(84, 2)
# Forward pass on a batch `x`.
113 def forward(self, x):
# First two stages: convolution, 2x2 max-pooling, then ReLU.
114 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
115 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
# Third convolution, without pooling, followed by ReLU.
116 x = fn.relu(self.conv3(x))
# First fully connected layer followed by ReLU.
118 x = fn.relu(self.fc1(x))
# Train a new Net on (train_input, train_target) with cross-entropy loss
# and SGD (learning rate 1e-2), in mini-batches of 100 samples, for
# args.nb_epochs epochs. The loss accumulated over each epoch is logged.
# NOTE(review): several statements are not present in the visible lines:
# the initialisation of nb_parameters and acc_loss, the body of the CUDA
# branch, the gradient/optimizer update, and the return statement (the
# caller uses the result as the trained model) -- confirm against the
# complete file.
122 def train_model(train_input, train_target):
123 model, criterion = Net(), nn.CrossEntropyLoss()
# Count and log the total number of scalar parameters of the model.
126 for p in model.parameters():
127 nb_parameters += p.numel()
128 log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
# NOTE(review): the body of this branch is not visible; presumably it
# moves the model (and criterion) to the GPU -- confirm.
130 if torch.cuda.is_available():
# `bs` is the mini-batch size.
134 optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
136 for k in range(0, args.nb_epochs):
# One pass over the training set, bs samples at a time.
# NOTE(review): narrow(0, b, bs) needs bs rows from offset b, so this
# assumes train_input.size(0) is a multiple of bs -- confirm.
138 for b in range(0, train_input.size(0), bs):
139 output = model.forward(train_input.narrow(0, b, bs))
140 loss = criterion(output, train_target.narrow(0, b, bs))
# NOTE(review): `loss.data[0]` is the pre-0.4 PyTorch idiom for reading a
# scalar loss; newer versions expect loss.item() -- confirm the
# targeted PyTorch version.
141 acc_loss = acc_loss + loss.data[0]
# Log the epoch index and the loss summed over its mini-batches.
145 log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
149 ######################################################################
# Compare the predictions of `model` on `data_input` with `data_target`,
# processing the data in mini-batches of `bs` samples.
# NOTE(review): the initialisation, increment and return of the error
# counter are not present in the visible lines; callers use the result as
# a number of errors -- confirm against the complete file.
151 def nb_errors(model, data_input, data_target, bs = 100):
154 for b in range(0, data_input.size(0), bs):
155 output = model.forward(data_input.narrow(0, b, bs))
# Winner-take-all prediction: index of the largest output along
# dimension 1, flattened to one entry per sample.
156 wta_prediction = output.data.max(1)[1].view(-1)
# Per-sample comparison with the targets of the same mini-batch.
# NOTE(review): range(0, bs) together with narrow(0, b, bs) assumes that
# data_input.size(0) is a multiple of bs -- confirm.
158 for i in range(0, bs):
159 if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
164 ######################################################################
# Main experiment: for each problem number from 1 to 23, generate the
# train and test sets, train a model, and log its train and test errors.
# problem_number is also read as a global by the logging statements.
166 for problem_number in range(1, 24):
167 train_input, train_target = generate_set(problem_number, args.nb_train_samples)
168 test_input, test_target = generate_set(problem_number, args.nb_test_samples)
# Move the four tensors to the GPU when CUDA is available.
170 if torch.cuda.is_available():
171 train_input, train_target = train_input.cuda(), train_target.cuda()
172 test_input, test_target = test_input.cuda(), test_target.cuda()
# Normalize both sets in place, using the mean and standard deviation of
# the training inputs only.
174 mu, std = train_input.data.mean(), train_input.data.std()
175 train_input.data.sub_(mu).div_(std)
176 test_input.data.sub_(mu).div_(std)
178 model = train_model(train_input, train_target)
# Errors on the training set, logged as a percentage of its size.
180 nb_train_errors = nb_errors(model, train_input, train_target)
# NOTE(review): the remaining format() arguments and the closing
# parentheses are not present in the visible lines -- confirm.
182 log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
183 100 * nb_train_errors / train_input.size(0),
# Same measurement on the test set.
188 nb_test_errors = nb_errors(model, test_input, test_target)
# NOTE(review): as above, the end of this call is not visible.
190 log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
191 100 * nb_test_errors / test_input.size(0),
196 ######################################################################