# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import torch
import resource

from math import sqrt

from torch import multiprocessing

from torch import Tensor
from torch.autograd import Variable

# Compiled svrt extension, providing seed, generate_vignettes,
# compress and uncompress
import svrt
######################################################################
def generate_one_batch(s):
    # s is a (problem_number, batch_size, random_seed) triplet: seed
    # the generator, draw balanced binary labels, render the
    # corresponding vignettes, and reshape them as a batch of
    # single-channel float images
    problem_number, batch_size, random_seed = s
    svrt.seed(random_seed)
    target = torch.LongTensor(batch_size).bernoulli_(0.5)
    input = svrt.generate_vignettes(problem_number, target)
    input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
    return [ input, target ]

######################################################################

class VignetteSet:

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):

        if nb_samples % batch_size > 0:
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.problem_number = problem_number

        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        # One random seed per batch, so that each batch can be
        # generated independently
        seeds = torch.LongTensor(self.nb_batches).random_()
        mp_args = []
        for b in range(0, self.nb_batches):
            mp_args.append([ problem_number, batch_size, seeds[b] ])

        self.data = []
        for b in range(0, self.nb_batches):
            self.data.append(generate_one_batch(mp_args[b]))
            if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
        # Weird thing going on with the multi-processing, waiting for
        # more info

        # pool = multiprocessing.Pool(multiprocessing.cpu_count())
        # self.data = pool.map(generate_one_batch, mp_args)
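        # A plausible, unverified explanation: Pool workers have to
        # pickle every generated batch and send it back through a
        # pipe, which is costly for full-size vignette tensors;
        # torch.multiprocessing's shared-memory tensor transport or a
        # different start method might behave differently.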
        # Normalize the whole set to zero mean and unit variance,
        # accumulating the first two moments batch by batch
        acc = 0.0
        acc_sq = 0.0
        for b in range(0, self.nb_batches):
            input = self.data[b][0]
            acc += input.sum() / input.numel()
            acc_sq += input.pow(2).sum() / input.numel()

        mean = acc / self.nb_batches
        std = sqrt(acc_sq / self.nb_batches - mean * mean)
        for b in range(0, self.nb_batches):
            self.data[b][0].sub_(mean).div_(std)
            if cuda:
                self.data[b][0] = self.data[b][0].cuda()
                self.data[b][1] = self.data[b][1].cuda()
    def get_batch(self, b):
        return self.data[b]
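    # Minimal usage sketch (the problem number and sizes below are
    # arbitrary illustrations, not defaults of this class):
    #
    #   vs = VignetteSet(problem_number = 1, nb_samples = 1000, batch_size = 100)
    #   for b in range(vs.nb_batches):
    #       input, target = vs.get_batch(b)
    #       # input is a normalized 100x1x128x128 float batch, target
    #       # the corresponding binary labels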
######################################################################
class CompressedVignetteSet:

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):

        if nb_samples % batch_size > 0:
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.cuda = cuda
        self.problem_number = problem_number

        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        acc = 0.0
        acc_sq = 0.0
        self.targets = []
        self.input_storages = []
        usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        for b in range(0, self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            input = svrt.generate_vignettes(problem_number, target)

            # FIXME input_as_float should not be necessary but there
            # are weird memory leaks going on, which do not seem to be
            # caused by this code

            input_as_float = input.float()
            input_as_float.copy_(input)
            acc += input_as_float.sum() / input.numel()
            acc_sq += input_as_float.pow(2).sum() / input.numel()

            self.targets.append(target)
            self.input_storages.append(svrt.compress(input.storage()))
            if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
            # Crude safety check: ru_maxrss is in kilobytes on Linux,
            # so this fires above roughly 16GB of resident memory
            if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6:
                print('Memory leak?!')
                raise RuntimeError('memory usage exceeded ~16GB during generation')

        mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - usage) * 1024
        print('Using {:.02f}GB total, {:.02f} bytes / sample'
              .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples))

        self.mean = acc / self.nb_batches
        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
    def get_batch(self, b):
        # Uncompress, reshape to 1x128x128 images, and normalize
        input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
        input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if self.cuda:
            input = input.cuda()
            target = target.cuda()

        return input, target
######################################################################
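# A minimal self-test sketch, assuming the compiled svrt extension is
# importable and that problem #1 with these arbitrary sizes is
# acceptable; it only checks that a CompressedVignetteSet can be built
# and that one batch can be decompressed and unpacked.

if __name__ == '__main__':

    vs = CompressedVignetteSet(problem_number = 1, nb_samples = 200, batch_size = 100)
    input, target = vs.get_batch(0)
    print('input', input.size(), 'target', target.size())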