2 # svrt is the ``Synthetic Visual Reasoning Test'', an image
3 # generator for evaluating classification performance of machine
4 # learning systems, humans and primates.
6 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
7 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 # This file is part of svrt.
11 # svrt is free software: you can redistribute it and/or modify it
12 # under the terms of the GNU General Public License version 3 as
13 # published by the Free Software Foundation.
15 # svrt is distributed in the hope that it will be useful, but
16 # WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 # General Public License for more details.
20 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
26 from torch import Tensor
27 from torch.autograd import Variable
31 ######################################################################
def __init__(self, problem_number, nb_batches, batch_size):
    """Generate `nb_batches` batches of vignettes and normalize them.

    For every batch, random 0/1 targets are drawn, the corresponding
    vignettes are obtained from `svrt.generate_vignettes`, converted
    to float and reshaped to (N, 1, H, W). Once all batches exist,
    every stored input is normalized in place with the mean and
    standard deviation accumulated over the whole set.

    Args:
        problem_number: forwarded unchanged to
            `svrt.generate_vignettes`.
        nb_batches: number of batches to generate, at least 1.
        batch_size: number of samples per batch.

    Raises:
        ValueError: if `nb_batches` is smaller than 1 (the statistics
            are averaged over the batches, so an empty set has none).
    """
    # Function-scope import: keeps this method independent of how the
    # module header imports the square root.
    from math import sqrt

    if nb_batches < 1:
        raise ValueError(
            'nb_batches must be at least 1, got {}'.format(nb_batches)
        )

    self.batch_size = batch_size
    self.problem_number = problem_number
    self.nb_batches = nb_batches
    self.nb_samples = self.nb_batches * self.batch_size
    self.targets = []
    self.inputs = []

    use_cuda = torch.cuda.is_available()

    # Running sums of the per-batch mean and mean of squares.
    acc = 0.0
    acc_sq = 0.0

    for _ in range(self.nb_batches):
        target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
        vignettes = svrt.generate_vignettes(problem_number, target)
        # Insert a singleton dimension after the batch dimension.
        vignettes = vignettes.float().view(
            vignettes.size(0), 1, vignettes.size(1), vignettes.size(2)
        )
        if use_cuda:
            # Keep inputs and targets on the same device.
            vignettes = vignettes.cuda()
            target = target.cuda()
        nb_values = vignettes.numel()
        # float() yields a plain Python number whether sum() returns a
        # scalar or a zero-dimensional tensor.
        acc += float(vignettes.sum()) / nb_values
        acc_sq += float(vignettes.pow(2).sum()) / nb_values
        self.targets.append(target)
        self.inputs.append(vignettes)

    mean = acc / self.nb_batches
    # Rounding can push the variance marginally below zero; clamp it so
    # that sqrt() does not raise a domain error.
    std = sqrt(max(acc_sq / self.nb_batches - mean * mean, 0.0))
    for batch in self.inputs:
        batch.sub_(mean).div_(std)
def get_batch(self, b):
    """Return the (input, target) pair stored at batch index `b`."""
    batch_input = self.inputs[b]
    batch_target = self.targets[b]
    return batch_input, batch_target
65 ######################################################################
class CompressedVignetteSet:
    """Set of vignette batches kept in compressed form.

    Each generated batch is passed through `svrt.compress` and only
    the result is stored; `get_batch` uncompresses and normalizes a
    batch on demand. The normalization statistics are exposed as
    `self.mean` and `self.std`.
    """

    def __init__(self, problem_number, nb_batches, batch_size):
        """Generate `nb_batches` batches and store them compressed.

        Args:
            problem_number: forwarded unchanged to
                `svrt.generate_vignettes`.
            nb_batches: number of batches to generate, at least 1.
            batch_size: number of samples per batch.

        Raises:
            ValueError: if `nb_batches` is smaller than 1 (the
                statistics are averaged over the batches).
        """
        # Function-scope import: keeps the class independent of how
        # the module header imports `math`.
        import math

        if nb_batches < 1:
            raise ValueError(
                'nb_batches must be at least 1, got {}'.format(nb_batches)
            )

        self.batch_size = batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size
        self.targets = []
        self.input_storages = []

        # Running sums of the per-batch mean and mean of squares.
        acc = 0.0
        acc_sq = 0.0

        for _ in range(self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)
            # Convert once and reuse for both statistics.
            as_float = vignettes.float()
            nb_values = vignettes.numel()
            # float() yields a plain Python number whether sum()
            # returns a scalar or a zero-dimensional tensor.
            acc += float(as_float.sum()) / nb_values
            acc_sq += float(as_float.pow(2).sum()) / nb_values
            self.targets.append(target)
            self.input_storages.append(svrt.compress(vignettes.storage()))

        self.mean = acc / self.nb_batches
        # Rounding can push the variance marginally below zero; clamp
        # it so that sqrt() does not raise a domain error.
        variance = acc_sq / self.nb_batches - self.mean * self.mean
        self.std = math.sqrt(max(variance, 0.0))

    def get_batch(self, b):
        """Return the normalized (input, target) pair of batch `b`.

        The stored batch is uncompressed with `svrt.uncompress`,
        converted to float, reshaped to (batch_size, 1, 128, 128) and
        normalized with `self.mean` and `self.std`. Both tensors are
        moved to the GPU when CUDA is available.
        """
        vignettes = torch.ByteTensor(
            svrt.uncompress(self.input_storages[b])
        ).float()
        # NOTE: the 128x128 vignette size is hard-coded here.
        vignettes = vignettes.view(self.batch_size, 1, 128, 128)
        vignettes = vignettes.sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if torch.cuda.is_available():
            # Keep inputs and targets on the same device.
            vignettes = vignettes.cuda()
            target = target.cuda()

        return vignettes, target
100 ######################################################################