Moved VignetteSet and CompressedVignetteSet into their own file.
[pysvrt.git] / vignette_set.py
1
2 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
3 #  generator for evaluating classification performance of machine
4 #  learning systems, humans and primates.
5 #
6 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
7 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8 #
9 #  This file is part of svrt.
10 #
11 #  svrt is free software: you can redistribute it and/or modify it
12 #  under the terms of the GNU General Public License version 3 as
13 #  published by the Free Software Foundation.
14 #
15 #  svrt is distributed in the hope that it will be useful, but
16 #  WITHOUT ANY WARRANTY; without even the implied warranty of
17 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18 #  General Public License for more details.
19 #
20 #  You should have received a copy of the GNU General Public License
21 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22
23 import torch
24 from math import sqrt
25
26 from torch import Tensor
27 from torch.autograd import Variable
28
29 import svrt
30
31 ######################################################################
32
class VignetteSet:
    """In-memory set of SVRT vignettes, standardized and ready for training.

    At construction time ``nb_batches`` batches of ``batch_size`` vignettes
    are generated for SVRT problem ``problem_number``. Each batch's labels
    are drawn with ``bernoulli_(0.5)``. Every input batch is stored as a
    float tensor of shape ``(n, 1, height, width)``. All stored inputs are
    then shifted and scaled in place by a mean and standard deviation,
    computed by averaging the per-batch pixel statistics. When CUDA is
    available, inputs and targets are kept on the GPU.
    """

    def __init__(self, problem_number, nb_batches, batch_size):
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.batch_size = batch_size
        self.nb_samples = nb_batches * batch_size
        self.inputs, self.targets = [], []

        # Running sums of the per-batch E[x] and E[x^2].
        sum_means, sum_sq_means = 0.0, 0.0

        for _ in range(nb_batches):
            labels = torch.LongTensor(batch_size).bernoulli_(0.5)
            raw = svrt.generate_vignettes(problem_number, labels)
            # Insert a singleton channel axis: (n, h, w) -> (n, 1, h, w).
            nb, height, width = raw.size(0), raw.size(1), raw.size(2)
            images = raw.float().view(nb, 1, height, width)
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()
            count = images.numel()
            sum_means += images.sum() / count
            sum_sq_means += images.pow(2).sum() / count
            self.targets.append(labels)
            self.inputs.append(images)

        # Var[x] = E[x^2] - E[x]^2, using the averaged per-batch moments.
        mean = sum_means / nb_batches
        std = sqrt(sum_sq_means / nb_batches - mean * mean)
        for images in self.inputs:
            images.sub_(mean).div_(std)

    def get_batch(self, b):
        """Return the ``(input, target)`` tensors stored for batch ``b``."""
        inputs, targets = self.inputs, self.targets
        return inputs[b], targets[b]
64
65 ######################################################################
66
class CompressedVignetteSet:
    """Set of SVRT vignettes kept compressed in memory, decoded on demand.

    At construction time ``nb_batches`` batches of ``batch_size`` vignettes
    are generated for SVRT problem ``problem_number``. Each batch's labels
    are drawn with ``bernoulli_(0.5)``. The raw pixel storage of every batch
    is compressed with ``svrt.compress`` and only the compressed form is kept.
    The mean and standard deviation of the pixel values are computed
    once, up front. ``get_batch`` then decompresses a batch and standardizes
    it with those statistics.
    """

    def __init__(self, problem_number, nb_batches, batch_size):
        self.batch_size = batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size
        self.targets = []
        self.input_storages = []
        # Spatial size of one vignette. It is read from the generated data
        # instead of being hard-coded, so get_batch() can rebuild the shape.
        self.height = None
        self.width = None

        # Running sums of the per-batch E[x] and E[x^2].
        acc = 0.0
        acc_sq = 0.0
        for _ in range(self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            raw = svrt.generate_vignettes(problem_number, target)
            # generate_vignettes yields a (n, height, width) tensor.
            self.height, self.width = raw.size(1), raw.size(2)
            pixels = raw.float()
            acc += pixels.sum() / pixels.numel()
            acc_sq += pixels.pow(2).sum() / pixels.numel()
            self.targets.append(target)
            self.input_storages.append(svrt.compress(raw.storage()))

        self.mean = acc / self.nb_batches
        # Var[x] = E[x^2] - E[x]^2. Only `sqrt` is imported from math;
        # the previous `math.sqrt` raised NameError on every construction.
        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)

    def get_batch(self, b):
        """Decompress batch ``b`` and return its ``(input, target)`` pair.

        The input is a float tensor of shape
        ``(batch_size, 1, height, width)`` standardized with ``self.mean``
        and ``self.std``. Both tensors are moved to the GPU when CUDA is
        available.
        """
        raw = torch.ByteTensor(svrt.uncompress(self.input_storages[b]))
        pixels = raw.float().view(self.batch_size, 1, self.height, self.width)
        pixels.sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if torch.cuda.is_available():
            pixels = pixels.cuda()
            target = target.cuda()

        return pixels, target
99
100 ######################################################################