3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
43 from vignette_set import VignetteSet, CompressedVignetteSet
45 ######################################################################
# Command-line interface.  ArgumentDefaultsHelpFormatter makes --help print
# each option's default value next to its description.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# The two data-set sizes are counted in mini-batches, not in samples: they
# are handed to the vignette sets together with --batch_size.
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many mini-batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many mini-batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

# Boolean switch: absent -> False, present -> True.
parser.add_argument('--compress_vignettes',
                    action = 'store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

# Parsed once at import time; the rest of the script reads options from `args`.
args = parser.parse_args()
78 ######################################################################
# Open the file named by --log_file in write mode ('w'), so an existing file
# with that name is truncated.
80 log_file = open(args.log_file, 'w')
# Tell the user on the console (in red) where the log is being written.
82 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# NOTE(review): the two statements below read a variable `s` that is not
# defined in the visible lines -- presumably they are the body of the
# log_string(s) helper called throughout this file; confirm.
# Prefix the message with the current time (time.ctime()) coloured green.
85 s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
# Append the stamped message to the log file as one line.
86 log_file.write(s + '\n')
90 ######################################################################
95 # ----------------------
97 # -- conv(21x21 x 6) -> 108x108 6
98 # -- max(2x2) -> 54x54 6
99 # -- conv(19x19 x 16) -> 36x36 16
100 # -- max(2x2) -> 18x18 16
101 # -- conv(18x18 x 120) -> 1x1 120
102 # -- reshape -> 120 1
103 # -- full(120x84) -> 84 1
104 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Shallow three-convolution network that outputs two class scores.

    For a single-channel 128x128 input the feature maps are:

        conv1 (21x21, 6 maps)    -> 108x108, max-pool 2x2 -> 54x54
        conv2 (19x19, 16 maps)   ->  36x36,  max-pool 2x2 -> 18x18
        conv3 (18x18, 120 maps)  ->   1x1
        fc1   (120 -> 84)
        fc2   (84 -> 2)

    The output is the raw result of the last linear layer (no softmax), which
    is the form expected by ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Short identifier used by the script to build parameter file names.
        self.name = 'shallownet'

    def forward(self, x):
        """Return a (batch, 2) tensor of class scores for a batch of images."""
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten per sample.  Keeping the batch dimension explicit means an
        # input of the wrong spatial size fails at fc1 instead of silently
        # spilling extra activations into the batch dimension.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        return self.fc2(x)
125 ######################################################################
def train_model(model, train_set):
    """Train `model` on `train_set` with plain SGD and cross-entropy loss.

    Runs `args.nb_epochs` passes over the `train_set.nb_batches` mini-batches
    returned by `train_set.get_batch(b)` as `(input, target)` pairs, and logs
    one `train_loss <epoch> <summed loss>` line per epoch via `log_string`.

    The model is updated in place and also returned for convenience.
    """
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): presumably the statement guarded by this check moved the
    # criterion to the GPU; the model itself is expected to already live on
    # the same device as the batches -- confirm against the caller.
    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for e in range(args.nb_epochs):
        acc_loss = 0.0

        for b in range(train_set.nb_batches):
            inputs, targets = train_set.get_batch(b)
            # Call the module rather than .forward() so registered hooks run.
            output = model(inputs)
            loss = criterion(output, targets)
            # .item() extracts the Python float from the 0-dim loss tensor;
            # indexing it with [0] raises on current PyTorch.
            acc_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))

    return model
150 ######################################################################
def nb_errors(model, data_set):
    """Count the samples of `data_set` that `model` misclassifies.

    For each of the `data_set.nb_batches` batches returned by
    `data_set.get_batch(b)` as `(input, target)`, the predicted class is the
    index of the largest score along dimension 1 of the model output.

    Returns the number of mismatches between predictions and targets as a
    Python int.
    """
    nb = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for b in range(data_set.nb_batches):
            inputs, targets = data_set.get_batch(b)
            output = model(inputs)
            # max(1) returns (values, indices); the indices are the
            # winner-take-all class predictions.
            predictions = output.max(1)[1].view(-1)
            # One vectorised comparison per batch instead of a Python loop
            # over individual tensor elements.
            nb += (predictions != targets.view(-1)).sum().item()

    return nb
165 ######################################################################
# Record every parsed command-line option in the log, one line per option,
# so that each run documents its own configuration.
for option_name, option_value in vars(args).items():
    log_string('argument ' + str(option_name) + ' ' + str(option_value))
170 ######################################################################
# Driver loop: processes SVRT problems 1 through 23, one at a time.
172 for problem_number in range(1, 24):
# A fresh, untrained network is created for every problem.
174 model = AfrozeShallowNet()
# NOTE(review): the statement guarded by this CUDA check is not visible in
# this listing -- presumably it moves the model to the GPU; confirm.
176 if torch.cuda.is_available():
# Parameter file name: <model name>_<problem number>_<nb_train_batches>.param
179 model_filename = model.name + '_' + \
180 str(problem_number) + '_' + \
181 str(args.nb_train_batches) + '.param'
# Sum the element counts of all parameter tensors and log the total.
# NOTE(review): nb_parameters has to be initialised before this loop; that
# line is not visible here.
184 for p in model.parameters(): nb_parameters += p.numel()
185 log_string('nb_parameters {:d}'.format(nb_parameters))
# Reuse previously saved parameters when possible: load the state dict from
# model_filename and log that it was loaded.
# NOTE(review): the control flow that switches need_to_train on when loading
# fails, and the condition guarding the training section below, are not
# visible in this listing -- confirm how a missing or invalid file is handled.
187 need_to_train = False
189 model.load_state_dict(torch.load(model_filename))
190 log_string('loaded_model ' + model_filename)
196 log_string('training_model ' + model_filename)
# Build the train and test sets for this problem.  --compress_vignettes
# selects CompressedVignetteSet; the second pair of constructions uses
# VignetteSet with the same arguments (the keyword separating the two groups
# is not visible here).  The cuda flag mirrors torch.cuda.is_available().
200 if args.compress_vignettes:
201 train_set = CompressedVignetteSet(problem_number,
202 args.nb_train_batches, args.batch_size,
203 cuda=torch.cuda.is_available())
204 test_set = CompressedVignetteSet(problem_number,
205 args.nb_test_batches, args.batch_size,
206 cuda=torch.cuda.is_available())
208 train_set = VignetteSet(problem_number,
209 args.nb_train_batches, args.batch_size,
210 cuda=torch.cuda.is_available())
211 test_set = VignetteSet(problem_number,
212 args.nb_test_batches, args.batch_size,
213 cuda=torch.cuda.is_available())
# Log the generation throughput: total number of samples in both sets divided
# by the seconds elapsed since `t`.
# NOTE(review): `t` is assigned outside the visible lines -- presumably a
# time.time() taken just before the sets are built; confirm.
215 log_string('data_generation {:0.2f} samples / s'.format(
216 (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
# Train the network on the train set, then persist its state dict under
# model_filename and log that it was saved.
219 train_model(model, train_set)
220 torch.save(model.state_dict(), model_filename)
221 log_string('saved_model ' + model_filename)
# Count misclassified training samples and log them; the visible format
# arguments are the error percentage and the train-set sample count (the
# remaining arguments of the format call are not visible here).
223 nb_train_errors = nb_errors(model, train_set)
225 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
227 100 * nb_train_errors / train_set.nb_samples,
229 train_set.nb_samples)
# Same measurement on the held-out test set.
232 nb_test_errors = nb_errors(model, test_set)
234 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
236 100 * nb_test_errors / test_set.nb_samples,
241 ######################################################################