3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import time

from colorama import Fore, Back, Style

import torch

from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

# NOTE(review): generate_set() calls svrt.generate_vignettes, but no
# `import svrt` is visible in this file -- confirm it is imported.
39 ######################################################################
# Command-line interface: data-set sizes, training length and log location.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000,
                    help = 'How many samples for train')

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000,
                    help = 'How many samples for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 25,
                    help = 'How many training epochs')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

args = parser.parse_args()

# Error rates are later computed as errors / number of samples, so empty
# sets would crash with a ZeroDivisionError after a full training run;
# reject them up front with a proper usage message instead.
if args.nb_train_samples < 1 or args.nb_test_samples < 1:
    parser.error('--nb_train_samples and --nb_test_samples must be positive')
64 ######################################################################
# Create (or truncate) this run's log file and tell the user where it lives.
log_file = open(args.log_file, 'w')
print('{}Logging into {}{}'.format(Fore.RED, args.log_file, Style.RESET_ALL))
def log_string(s):
    """Timestamp ``s``, append it to the log file and echo it to stdout.

    The line is prefixed with the current time (coloured green) and the
    global ``problem_number`` assigned by the main loop, so this must only
    be called once that loop has started.
    """
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
        str(problem_number) + ' ' + s
    log_file.write(s + '\n')
    # Flush every line so the log stays complete even if a long run is
    # interrupted part-way through.
    log_file.flush()
    print(s)
77 ######################################################################
def generate_set(p, n):
    """Generate ``n`` labelled vignettes for SVRT problem ``p``.

    Labels are drawn i.i.d. from Bernoulli(0.5) and handed to
    ``svrt.generate_vignettes``; the generation throughput is logged.

    Args:
        p: SVRT problem number.
        n: number of samples to generate.

    Returns:
        ``(input, target)`` wrapped in ``Variable``: the images as a float
        tensor with a singleton channel dimension inserted at dim 1, and the
        LongTensor of 0/1 labels.
    """
    target = torch.LongTensor(n).bernoulli_(0.5)

    start = time.time()
    data = svrt.generate_vignettes(p, target)
    elapsed = time.time() - start
    # Guard against a zero interval from a coarse clock on tiny sets.
    log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(
        n / max(elapsed, 1e-9)))

    # Presumably (N, H, W) from the generator -> (N, 1, H, W) for Conv2d.
    data = data.view(data.size(0), 1, data.size(1), data.size(2)).float()
    return Variable(data), Variable(target)
88 ######################################################################
# 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5

class Net(nn.Module):
    """Two-conv-layer classifier for single-channel 128x128 vignettes.

    Following the shape trace above, the second pooling stage leaves 20
    feature maps of 5x5 (= 500 features), which feed two fully connected
    layers ending in 2 unnormalised class scores.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
        self.fc1 = nn.Linear(500, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        """Map an (N, 1, 128, 128) batch to (N, 2) logits."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
        # Flatten per sample. Keeping the batch dim explicit means a wrong
        # input size makes fc1 fail instead of silently changing N.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself.
        return self.fc2(x)
def train_model(train_input, train_target, lr = 1e-1, bs = 100):
    """Train a fresh ``Net`` on the given set and return it.

    Runs ``args.nb_epochs`` epochs of mini-batch Adam on the cross-entropy
    loss, logging the sum of the per-batch losses after each epoch.

    Args:
        train_input: training images, batched along dim 0 (as built by
            ``generate_set``).
        train_target: class labels aligned with ``train_input`` on dim 0.
        lr: Adam learning rate.
        bs: mini-batch size; the last batch may be smaller.

    Returns:
        The trained model (moved to the GPU when CUDA is available).
    """
    model, criterion = Net(), nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    # NOTE(review): lr=1e-1 is 100x Adam's usual 1e-3 and may make training
    # diverge; kept as the default only to preserve existing behaviour.
    optimizer = optim.Adam(model.parameters(), lr = lr)
    nb_samples = train_input.size(0)

    for k in range(0, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, nb_samples, bs):
            # Clamp so a sample count that is not a multiple of bs does not
            # make narrow() run past the end of the tensor.
            size = min(bs, nb_samples - b)
            output = model(train_input.narrow(0, b, size))
            loss = criterion(output, train_target.narrow(0, b, size))
            # .item() works on the 0-dim loss of current PyTorch, where
            # loss.data[0] raises an IndexError.
            acc_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))

    return model
130 ######################################################################
def nb_errors(model, data_input, data_target, bs = 100):
    """Count the samples ``model`` misclassifies.

    The prediction for each sample is the arg-max over the model's output
    scores along dim 1 (winner-take-all), compared to its label.

    Args:
        model: callable mapping a batch to (batch, nb_classes) scores.
        data_input: samples, batched along dim 0.
        data_target: integer labels aligned with ``data_input`` on dim 0.
        bs: evaluation batch size; the last batch may be smaller.

    Returns:
        The number of misclassified samples, as an int.
    """
    nb_samples = data_input.size(0)
    errors = 0

    # Evaluation only: skip building autograd graphs for every batch.
    with torch.no_grad():
        for b in range(0, nb_samples, bs):
            # Clamp the final batch when nb_samples is not a multiple of bs.
            size = min(bs, nb_samples - b)
            output = model(data_input.narrow(0, b, size))
            wta_prediction = output.data.max(1)[1].view(-1)
            target = data_target.narrow(0, b, size).data.view(-1)
            errors += int((wta_prediction != target).long().sum())

    return errors
145 ######################################################################
# For each SVRT problem (1 to 23): generate train/test sets, standardise
# them, train a fresh network, and log its train and test error rates.
# Note: log_string() reads the loop variable `problem_number` as a global.
for problem_number in range(1, 24):
    train_input, train_target = generate_set(problem_number, args.nb_train_samples)
    test_input, test_target = generate_set(problem_number, args.nb_test_samples)

    if torch.cuda.is_available():
        train_input, train_target = train_input.cuda(), train_target.cuda()
        test_input, test_target = test_input.cuda(), test_target.cuda()

    # Standardise both sets with statistics from the training set only.
    mu, std = train_input.data.mean(), train_input.data.std()
    train_input.data.sub_(mu).div_(std)
    test_input.data.sub_(mu).div_(std)

    model = train_model(train_input, train_target)

    nb_train_errors = nb_errors(model, train_input, train_target)

    # 100.0 keeps the percentage fractional under Python 2 integer division.
    log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
        100.0 * nb_train_errors / train_input.size(0),
        nb_train_errors,
        train_input.size(0))
    )

    nb_test_errors = nb_errors(model, test_input, test_target)

    log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
        100.0 * nb_test_errors / test_input.size(0),
        nb_test_errors,
        test_input.size(0))
    )
177 ######################################################################