3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
28 from colorama import Fore, Back, Style
32 from torch import optim
33 from torch import FloatTensor as Tensor
34 from torch.autograd import Variable
36 from torch.nn import functional as fn
37 from torchvision import datasets, transforms, utils
41 ######################################################################
# Command-line configuration of the experiment. Dataset sizes are given in
# mini-batches: a set built from nb_*_batches holds
# nb_*_batches * batch_size vignettes.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# These two options count mini-batches, not individual samples, so the help
# text says so (the previous wording "How many samples" was misleading).
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many mini-batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many mini-batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Should we use lossless compression of vignette to reduce the memory footprint')

args = parser.parse_args()
74 ######################################################################
# Open (truncating) the log file that the logging helper appends to, then
# tell the user on the terminal, in red, where the log is going.
log_file = open(args.log_file, mode = 'w')
print(''.join([Fore.RED, 'Logging into ', args.log_file, Style.RESET_ALL]))
# Body of the logging helper (presumably `log_string(s)`; its `def` line and
# any following lines are not visible in this view). It prefixes the message
# with the current time, colored green, and appends it as one line to
# log_file.
# NOTE(review): the colorama escape codes (Fore.GREEN, Style.RESET_ALL) are
# part of `s`, so they end up in the text written to log_file as well.
81 s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
82 log_file.write(s + '\n')
86 ######################################################################
# Constructor of the in-memory vignette set. The class header is not visible
# in this view; presumably this is VignetteSet, given how it is constructed
# further down.
# Builds `nb_batches` mini-batches of `args.batch_size` vignettes for SVRT
# problem `problem_number`. It then normalizes every input in place to zero
# mean and unit variance, using statistics accumulated over the whole set.
# NOTE(review): several lines of this method are not visible here, e.g. the
# initialization of self.targets, self.inputs, acc and acc_sq, and one line
# of the CUDA branch -- confirm against the full file.
89 def __init__(self, problem_number, nb_batches):
90 self.batch_size = args.batch_size
91 self.problem_number = problem_number
92 self.nb_batches = nb_batches
93 self.nb_samples = self.nb_batches * self.batch_size
100 for k in range(0, self.nb_batches):
# Random 0/1 labels (Bernoulli(0.5)), one per vignette of the batch.
101 target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
# Presumably svrt returns one (H, W) vignette per entry of `target`
# (TODO confirm). Reshape to (N, 1, H, W), the layout Conv2d expects.
102 input = svrt.generate_vignettes(problem_number, target)
103 input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
104 if torch.cuda.is_available():
106 target = target.cuda()
# Per-batch mean of the pixels and of their squares. These are averaged over
# batches below, which is valid because every batch has the same size.
107 acc += input.float().sum() / input.numel()
108 acc_sq += input.float().pow(2).sum() / input.numel()
109 self.targets.append(target)
110 self.inputs.append(input)
# Dataset-wide statistics; std comes from Var[x] = E[x^2] - E[x]^2.
112 mean = acc / self.nb_batches
113 std = math.sqrt(acc_sq / self.nb_batches - mean * mean)
# Normalize every stored batch in place.
114 for k in range(0, self.nb_batches):
115 self.inputs[k].sub_(mean).div_(std)
def get_batch(self, b):
    """Return the stored (input, target) tensors of mini-batch number ``b``."""
    batch_input = self.inputs[b]
    batch_target = self.targets[b]
    return batch_input, batch_target
# Variant of the in-memory vignette set that reduces memory use. Each
# mini-batch of raw vignettes is kept compressed (svrt.compress), and a
# batch is only decompressed and normalized when get_batch() is called.
# NOTE(review): some lines of this class are not visible in this view, e.g.
# the initialization of self.targets, acc and acc_sq, part of the CUDA
# branch, and get_batch's return statement -- confirm against the full file.
120 class CompressedVignetteSet:
121 def __init__(self, problem_number, nb_batches):
122 self.batch_size = args.batch_size
123 self.problem_number = problem_number
124 self.nb_batches = nb_batches
125 self.nb_samples = self.nb_batches * self.batch_size
# One compressed storage per mini-batch.
127 self.input_storages = []
131 for k in range(0, self.nb_batches):
# Random 0/1 labels (Bernoulli(0.5)), one per vignette of the batch.
132 target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
133 input = svrt.generate_vignettes(problem_number, target)
# Per-batch mean of the pixels and of their squares, averaged below.
134 acc += input.float().sum() / input.numel()
135 acc_sq += input.float().pow(2).sum() / input.numel()
136 self.targets.append(target)
137 self.input_storages.append(svrt.compress(input.storage()))
# Dataset-wide normalization statistics, applied lazily in get_batch.
# std comes from Var[x] = E[x^2] - E[x]^2.
139 self.mean = acc / self.nb_batches
140 self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
142 def get_batch(self, b):
# Decompress batch b into a byte tensor and convert it to float. Then
# reshape it to (batch_size, 1, 128, 128) and normalize it in place.
# NOTE(review): the 128x128 vignette size is hard-coded here, whereas
# VignetteSet derives H and W from the generated tensor.
143 input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
144 input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
145 target = self.targets[b]
147 if torch.cuda.is_available():
149 target = target.cuda()
153 ######################################################################
155 # Afroze's ShallowNet
158 # ----------------------
160 # -- conv(21x21 x 6) -> 108x108 6
161 # -- max(2x2) -> 54x54 6
162 # -- conv(19x19 x 16) -> 36x36 16
163 # -- max(2x2) -> 18x18 16
164 # -- conv(18x18 x 120) -> 1x1 120
165 # -- reshape -> 120 1
166 # -- full(120x84) -> 84 1
167 # -- full(84x2) -> 2 1
# Small convnet classifier following the "Afroze's ShallowNet" layout.
# It has three Conv2d layers (1->6, 6->16 and 16->120 channels, kernel sizes
# 21/19/18), followed by two Linear layers (120->84, 84->2), giving two
# output scores, one per class.
# NOTE(review): this view shows neither the constructor's `def __init__` line
# nor the end of forward() after fc1 -- confirm against the full file.
169 class AfrozeShallowNet(nn.Module):
171 super(AfrozeShallowNet, self).__init__()
172 self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
173 self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
174 self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
175 self.fc1 = nn.Linear(120, 84)
176 self.fc2 = nn.Linear(84, 2)
178 def forward(self, x):
# Expects single-channel images (conv1 has one input channel). Per the
# layout comment, a 128x128 input is reduced to 120 x 1 x 1 after conv3.
# Each conv block applies max-pooling before the ReLU.
179 x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
180 x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
181 x = fn.relu(self.conv3(x))
# NOTE(review): fc1 takes 120 features, so a reshape of x presumably happens
# on a line not visible here.
183 x = fn.relu(self.fc1(x))
# Train `model` on every mini-batch of `train_set` for args.nb_epochs epochs.
# It uses cross-entropy loss and plain SGD (lr = 1e-2), and logs the
# accumulated loss as "train_loss <epoch> <loss>".
# NOTE(review): several lines are not visible in this view, e.g. the body of
# the CUDA branch, the initialization of acc_loss, and whatever follows the
# loss accumulation -- confirm against the full file.
187 def train_model(model, train_set):
# NOTE(review): batch_size is not used by any visible line of this function.
188 batch_size = args.batch_size
189 criterion = nn.CrossEntropyLoss()
191 if torch.cuda.is_available():
194 optimizer = optim.SGD(model.parameters(), lr = 1e-2)
196 for k in range(0, args.nb_epochs):
198 for b in range(0, train_set.nb_batches):
199 input, target = train_set.get_batch(b)
# Variable(...) is the pre-0.4 PyTorch autograd wrapper. Calling forward()
# directly, instead of model(...), skips nn.Module hooks.
200 output = model.forward(Variable(input))
201 loss = criterion(output, Variable(target))
# NOTE(review): since PyTorch 0.4 the loss is a 0-dim tensor, and later
# versions raise IndexError on `loss.data[0]`. `loss.item()` is the
# supported way to read the scalar.
202 acc_loss = acc_loss + loss.data[0]
206 log_string('train_loss {:d} {:f}'.format(k, acc_loss))
210 ######################################################################
# Count the samples of `data_set` that `model` misclassifies, as the name
# and the callers' error-rate computation indicate. The predicted class is
# the arg-max of the output scores ("winner takes all").
# NOTE(review): the counter's initialization, its increment inside the
# comparison, and the return statement are not visible in this view.
212 def nb_errors(model, data_set):
214 for b in range(0, data_set.nb_batches):
215 input, target = data_set.get_batch(b)
# NOTE(review): evaluation runs with autograd enabled. On PyTorch >= 0.4,
# wrapping it in torch.no_grad() would avoid building computation graphs.
216 output = model.forward(Variable(input))
# .max(1) returns (values, indices) along dim 1; [1] keeps the arg-max indices.
217 wta_prediction = output.data.max(1)[1].view(-1)
# NOTE(review): this per-sample Python loop could be vectorized as
# (wta_prediction != target).sum().
219 for i in range(0, data_set.batch_size):
220 if wta_prediction[i] != target[i]:
225 ######################################################################
# Log each parsed command-line option as "argument <name> <value>".
for arg, arg_value in vars(args).items():
    log_string('argument ' + str(arg) + ' ' + str(arg_value))
# Run the experiment for each SVRT problem 1..23:
#   - build the train and test sets, compressed or not per
#     --compress_vignettes;
#   - create a fresh AfrozeShallowNet;
#   - obtain its parameters (loaded from model_<n>.param, or trained and
#     saved there);
#   - log the train and test error rates.
# NOTE(review): several lines are not visible in this view, including:
#   - the branch keyword before the VignetteSet lines (presumably `else:`);
#   - the CUDA branch body;
#   - the nb_parameters initialization;
#   - the control flow that decides between loading and training;
#   - some of the format arguments.
# Confirm against the full file.
230 for problem_number in range(1, 24):
231 if args.compress_vignettes:
232 train_set = CompressedVignetteSet(problem_number, args.nb_train_batches)
233 test_set = CompressedVignetteSet(problem_number, args.nb_test_batches)
235 train_set = VignetteSet(problem_number, args.nb_train_batches)
236 test_set = VignetteSet(problem_number, args.nb_test_batches)
238 model = AfrozeShallowNet()
240 if torch.cuda.is_available():
# Total number of scalar parameters in the model.
244 for p in model.parameters():
245 nb_parameters += p.numel()
246 log_string('nb_parameters {:d}'.format(nb_parameters))
# One parameter file per problem, e.g. model_1.param.
248 model_filename = 'model_' + str(problem_number) + '.param'
# NOTE(review): torch.load unpickles the file; only load trusted parameter files.
251 model.load_state_dict(torch.load(model_filename))
252 log_string('loaded_model ' + model_filename)
254 log_string('training_model')
255 train_model(model, train_set)
256 torch.save(model.state_dict(), model_filename)
257 log_string('saved_model ' + model_filename)
259 nb_train_errors = nb_errors(model, train_set)
# Error rate is reported in percent of the set's samples.
261 log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
263 100 * nb_train_errors / train_set.nb_samples,
265 train_set.nb_samples)
268 nb_test_errors = nb_errors(model, test_set)
270 log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
272 100 * nb_test_errors / test_set.nb_samples,
277 ######################################################################