3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import os
import time

import torch
from torch import nn
from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

from colorama import Fore, Back, Style

from vignette_set import VignetteSet, CompressedVignetteSet
41 ######################################################################
# Command-line interface. ArgumentDefaultsHelpFormatter shows the default
# of every option in --help output.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many samples for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many samples for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

args = parser.parse_args()
74 ######################################################################
# Open the run log (truncating any previous one) and announce its name in
# red on stdout.  NOTE(review): the file is kept open for the whole run and
# never explicitly closed — relies on interpreter exit to flush it.
log_file = open(args.log_file, 'w')
print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
def log_string(s):
    # Prefix the message with a green timestamp and append it to the log
    # file opened above.  (The `def` header was missing in the source; the
    # body referenced an undefined `s`, and the function is called below.)
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
86 ######################################################################
91 # ----------------------
93 # -- conv(21x21 x 6) -> 108x108 6
94 # -- max(2x2) -> 54x54 6
95 # -- conv(19x19 x 16) -> 36x36 16
96 # -- max(2x2) -> 18x18 16
97 # -- conv(18x18 x 120) -> 1x1 120
99 # -- full(120x84) -> 84 1
100 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    # Shallow convnet for 1-channel 128x128 vignettes (per the layer sizes
    # in the comment above):
    #   conv 21x21 x6 -> 108x108x6, max 2x2 -> 54x54x6,
    #   conv 19x19 x16 -> 36x36x16, max 2x2 -> 18x18x16,
    #   conv 18x18 x120 -> 1x1x120, full 120->84, full 84->2.

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # The 1x1x120 feature map must be flattened before the linear
        # layers (this and the final return were missing in the source).
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
def train_model(model, train_set):
    """Train `model` on `train_set` for args.nb_epochs epochs with plain SGD,
    logging the accumulated cross-entropy loss after every epoch, and return
    the trained model."""

    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for e in range(0, args.nb_epochs):
        # Reset the per-epoch loss accumulator (missing in the source,
        # which read acc_loss before any assignment).
        acc_loss = 0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            # The actual gradient step was missing in the source: the loss
            # was computed but never backpropagated, so nothing trained.
            model.zero_grad()
            loss.backward()
            optimizer.step()
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))

    return model
143 ######################################################################
def nb_errors(model, data_set):
    """Return the number of winner-take-all classification errors that
    `model` makes over all batches of `data_set`."""

    # Error counter (its initialization, increment, and return were
    # missing in the source).
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Winner-take-all: index of the max output per sample.
        wta_prediction = output.data.max(1)[1].view(-1)
        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
    return ne
158 ######################################################################
# Record every command-line argument value in the log so runs are
# reproducible from the log file alone.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
# Main experiment loop: one model per SVRT problem (problems 1..23).
for problem_number in range(1, 24):
    if args.compress_vignettes:
        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
    else:
        # The `else:` was missing in the source, so the uncompressed sets
        # unconditionally overwrote the compressed ones.
        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)

    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    nb_parameters = 0
    for p in model.parameters():
        nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    model_filename = 'model_' + str(problem_number) + '.param'

    # Reuse a previously trained model when one exists; otherwise train and
    # save.  (The source loaded unconditionally — crashing when the file
    # was absent — and then trained regardless.)
    if os.path.isfile(model_filename):
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    else:
        log_string('training_model')
        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

    nb_train_errors = nb_errors(model, train_set)

    # Four placeholders need four arguments; problem_number and the raw
    # error count were missing in the source.
    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_train_errors / train_set.nb_samples,
        nb_train_errors,
        train_set.nb_samples)
    )

    nb_test_errors = nb_errors(model, test_set)

    log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100 * nb_test_errors / test_set.nb_samples,
        nb_test_errors,
        test_set.nb_samples)
    )
210 ######################################################################