3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with svrt. If not, see <http://www.gnu.org/licenses/>.
import argparse
import math
import time

from colorama import Fore, Back, Style

import torch
from torch import nn
from torch import optim
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils

import svrt
41 ######################################################################
# Command-line options. ArgumentDefaultsHelpFormatter makes --help also
# print the default value of every option.
parser = argparse.ArgumentParser(
    description = 'Simple convnet test on the SVRT.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# Both set sizes are counted in mini-batches of --batch_size vignettes,
# not in individual samples.
parser.add_argument('--nb_train_batches',
                    type = int, default = 1000,
                    help = 'How many batches for train')

parser.add_argument('--nb_test_batches',
                    type = int, default = 100,
                    help = 'How many batches for test')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many training epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Mini-batch size')

parser.add_argument('--log_file',
                    type = str, default = 'cnn-svrt.log',
                    help = 'Log file name')

parser.add_argument('--compress_vignettes',
                    action='store_true', default = False,
                    help = 'Use lossless compression to reduce the memory footprint')

args = parser.parse_args()
74 ######################################################################
# The log file is truncated at start-up and kept open for the whole run.
log_file = open(args.log_file, 'w')

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

def log_string(s):
    """Prefix ``s`` with the current local time and record it.

    The line is appended to ``log_file`` and echoed on stdout. The
    timestamp is wrapped in colorama colour codes: they render on a
    terminal but are stored as raw escape sequences in the file.
    """
    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
    log_file.write(s + '\n')
    # Flush every line: the file is not closed anywhere in this script,
    # so a long run that gets killed would otherwise lose buffered lines.
    log_file.flush()
    print(s)
86 ######################################################################
class VignetteSet:
    """Pre-generated, normalized SVRT vignettes kept uncompressed in memory.

    Holds ``nb_batches`` mini-batches of ``args.batch_size`` vignettes for
    one SVRT problem. Inputs are float tensors of shape
    (batch_size, 1, height, width), normalized in place with the mean and
    standard deviation of the whole set; targets are random 0/1 labels.
    Everything is moved to the GPU when one is available.
    """

    def __init__(self, problem_number, nb_batches):
        self.batch_size = args.batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size
        self.targets = []
        self.inputs = []

        acc = 0.0
        acc_sq = 0.0

        for b in range(self.nb_batches):
            # Random 0/1 labels (p = 0.5); svrt is asked for vignettes
            # matching them.
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)
            # Insert the single channel dimension expected by nn.Conv2d.
            vignettes = vignettes.float().view(vignettes.size(0), 1,
                                               vignettes.size(1), vignettes.size(2))
            if torch.cuda.is_available():
                vignettes = vignettes.cuda()
                target = target.cuda()
            acc += vignettes.sum() / vignettes.numel()
            acc_sq += vignettes.pow(2).sum() / vignettes.numel()
            self.targets.append(target)
            self.inputs.append(vignettes)

        # All batches have the same size, so averaging the per-batch moments
        # gives the global ones; std = sqrt(E[x^2] - E[x]^2). Kept as plain
        # floats (like CompressedVignetteSet) rather than 0-dim tensors.
        self.mean = float(acc / self.nb_batches)
        self.std = math.sqrt(float(acc_sq / self.nb_batches) - self.mean * self.mean)
        for vignettes in self.inputs:
            vignettes.sub_(self.mean).div_(self.std)

    def get_batch(self, b):
        """Return the ``(input, target)`` pair of mini-batch ``b``."""
        return self.inputs[b], self.targets[b]
120 ######################################################################
class CompressedVignetteSet:
    """Same interface as VignetteSet, with batches stored compressed.

    Each batch of raw vignettes is kept as an ``svrt.compress``-ed byte
    storage and is decompressed, reshaped to (batch_size, 1, height, width),
    normalized and (if available) moved to the GPU on every ``get_batch``
    call, trading CPU time for a smaller memory footprint.
    """

    def __init__(self, problem_number, nb_batches):
        self.batch_size = args.batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size
        self.targets = []
        self.input_storages = []
        # (height, width) of the vignettes: compression keeps only the flat
        # storage, so get_batch needs this to restore the shape.
        self.vignette_size = None

        acc = 0.0
        acc_sq = 0.0

        for b in range(self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)
            acc += vignettes.float().sum() / vignettes.numel()
            acc_sq += vignettes.float().pow(2).sum() / vignettes.numel()
            self.targets.append(target)
            self.vignette_size = (vignettes.size(1), vignettes.size(2))
            self.input_storages.append(svrt.compress(vignettes.storage()))

        # Global moments from equal-size batches; std = sqrt(E[x^2] - E[x]^2).
        self.mean = float(acc / self.nb_batches)
        self.std = math.sqrt(float(acc_sq / self.nb_batches) - self.mean * self.mean)

    def get_batch(self, b):
        """Decompress, normalize and return the ``(input, target)`` pair of batch ``b``."""
        height, width = self.vignette_size
        vignettes = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
        vignettes = vignettes.view(self.batch_size, 1, height, width).sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if torch.cuda.is_available():
            vignettes = vignettes.cuda()
            target = target.cuda()

        return vignettes, target
155 ######################################################################
class AfrozeShallowNet(nn.Module):
    """Afroze's ShallowNet: three convolutions followed by two linear layers.

    Maps (N, 1, 128, 128) inputs to (N, 2) logits:

        layer                output map size   nb. maps
        input                    128x128           1
        conv(21x21 x 6)      ->  108x108           6
        max(2x2)             ->   54x54            6
        conv(19x19 x 16)     ->   36x36           16
        max(2x2)             ->   18x18           16
        conv(18x18 x 120)    ->    1x1           120
        reshape              ->  120               1
        full(120x84)         ->   84               1
        full(84x2)           ->    2               1
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # For 128x128 inputs conv3 yields 120 maps of 1x1. Flatten per
        # sample so that any other input size fails in fc1 instead of being
        # silently re-batched.
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        # Raw logits, no softmax: train_model feeds them to nn.CrossEntropyLoss.
        x = self.fc2(x)
        return x
def train_model(model, train_set):
    """Train ``model`` in place on ``train_set`` with plain SGD (lr = 1e-2).

    Runs ``args.nb_epochs`` epochs over every mini-batch of ``train_set``
    (any object exposing ``nb_batches`` and ``get_batch(b)``), minimizing the
    cross-entropy of the model's logits, and logs after each epoch the sum
    of the per-batch losses.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # Move the model before building the optimizer so that it holds the
        # GPU parameters.
        model.cuda()
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    for e in range(args.nb_epochs):
        acc_loss = 0.0
        for b in range(train_set.nb_batches):
            x, target = train_set.get_batch(b)
            output = model(x)
            loss = criterion(output, target)
            # .item() yields a Python float; loss.data[0] raises on the
            # 0-dim loss tensors of PyTorch >= 0.4.
            acc_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))

    return model
212 ######################################################################
def nb_errors(model, data_set):
    """Count the samples of ``data_set`` that ``model`` misclassifies.

    The predicted class of a sample is the index of its largest output
    (winner-take-all). ``data_set`` must expose ``nb_batches`` and
    ``get_batch(b)`` returning ``(input, target)``.

    Returns:
        The total number of errors over all batches, as a Python ``int``.
    """
    ne = 0
    # Evaluation only: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for b in range(data_set.nb_batches):
            x, target = data_set.get_batch(b)
            output = model(x)
            wta_prediction = output.max(1)[1].view(-1)
            # Vectorized comparison over the whole batch, whatever its size.
            ne += int((wta_prediction != target.view(-1)).sum())
    return ne
227 ######################################################################
# Record the full configuration at the top of the log.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))

# Train and evaluate one network per SVRT problem (problems 1 to 23).
for problem_number in range(1, 24):
    if args.compress_vignettes:
        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches)
        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches)
    else:
        train_set = VignetteSet(problem_number, args.nb_train_batches)
        test_set = VignetteSet(problem_number, args.nb_test_batches)

    model = AfrozeShallowNet()

    if torch.cuda.is_available():
        model.cuda()

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    model_filename = 'model_' + str(problem_number) + '.param'

    # Reuse a previously saved model for this problem when the file exists;
    # otherwise train from scratch and save the result.
    try:
        # Load onto the CPU first so a checkpoint saved on a GPU machine also
        # loads on a CPU-only one; load_state_dict then copies the weights
        # onto the model's own device.
        model.load_state_dict(torch.load(model_filename,
                                         map_location = lambda storage, loc: storage))
        log_string('loaded_model ' + model_filename)
    except IOError:
        log_string('training_model')
        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

    # 100.0 keeps the percentage fractional even under Python 2 division.
    nb_train_errors = nb_errors(model, train_set)

    log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100.0 * nb_train_errors / train_set.nb_samples,
        nb_train_errors,
        train_set.nb_samples)
    )

    nb_test_errors = nb_errors(model, test_set)

    log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
        problem_number,
        100.0 * nb_test_errors / test_set.nb_samples,
        nb_test_errors,
        test_set.nb_samples)
    )
279 ######################################################################