The learning continues from where it stopped if a state is found.
[pysvrt.git] / svrtset.py
1
2 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
3 #  generator for evaluating classification performance of machine
4 #  learning systems, humans and primates.
5 #
6 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
7 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8 #
9 #  This file is part of svrt.
10 #
11 #  svrt is free software: you can redistribute it and/or modify it
12 #  under the terms of the GNU General Public License version 3 as
13 #  published by the Free Software Foundation.
14 #
15 #  svrt is distributed in the hope that it will be useful, but
16 #  WITHOUT ANY WARRANTY; without even the implied warranty of
17 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18 #  General Public License for more details.
19 #
20 #  You should have received a copy of the GNU General Public License
21 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22
23 import torch
24 from math import sqrt
25 from torch import multiprocessing
26
27 from torch import Tensor
28 from torch.autograd import Variable
29
30 import svrt
31
32 # FIXME
33 import resource
34
35 ######################################################################
36
def generate_one_batch(s):
    """Generate one labelled batch of SVRT vignettes.

    Args:
        s: a (problem_number, batch_size, random_seed) triple, packed in a
            single argument so the function can be used with Pool.map.

    Returns:
        A two-element list ``[input, target]``:
        ``target`` is a LongTensor of ``batch_size`` random 0/1 labels, and
        ``input`` holds the generated vignettes converted to float with a
        singleton channel axis inserted after the batch axis.
    """
    problem_number, batch_size, random_seed = s
    # NOTE(review): only svrt's generator is seeded here; the labels below
    # come from torch's global RNG. Confirm whether svrt.seed also seeds torch.
    svrt.seed(random_seed)
    labels = torch.LongTensor(batch_size).bernoulli_(0.5)
    images = svrt.generate_vignettes(problem_number, labels)
    # (N, H, W) -> (N, 1, H, W)
    nb, height, width = images.size(0), images.size(1), images.size(2)
    images = images.float().view(nb, 1, height, width)
    # A list (not a tuple) because callers replace its items in place.
    return [images, labels]
44
class VignetteSet:
    """In-memory set of SVRT vignettes, generated eagerly and standardized.

    The constructor generates ``nb_samples`` vignettes in
    ``nb_samples // batch_size`` batches, one ``generate_one_batch`` call per
    batch with its own random seed. It then shifts and scales all pixel values
    in place to zero mean and unit standard deviation over the whole set.
    Finally it optionally moves every batch to the GPU.
    """

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
        """Generate and normalize the whole set.

        Args:
            problem_number: SVRT problem number forwarded to
                generate_one_batch.
            nb_samples: total number of vignettes; must be a positive
                multiple of batch_size.
            batch_size: number of vignettes per batch; must be positive.
            cuda: if true, the input and target of every batch are moved to
                the GPU with ``.cuda()``.
            logger: optional callable invoked after batch ``b`` is generated
                as ``logger(nb_batches * batch_size, b * batch_size)``.

        Raises:
            ValueError: if nb_samples or batch_size is not positive, or if
                nb_samples is not a multiple of batch_size.
        """
        if batch_size <= 0 or nb_samples <= 0:
            raise ValueError('nb_samples and batch_size must be positive')
        if nb_samples % batch_size > 0:
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.cuda = cuda
        self.problem_number = problem_number

        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        # One seed per batch, packed as the argument triples generate_one_batch
        # expects, so that batches can also be produced by a process pool.
        seeds = torch.LongTensor(self.nb_batches).random_()
        mp_args = [[problem_number, batch_size, seeds[b]]
                   for b in range(self.nb_batches)]

        # Batches are generated sequentially. A
        # multiprocessing.Pool(cpu_count()).map(generate_one_batch, mp_args)
        # version was disabled because of unexplained multiprocessing issues.
        self.data = []
        for b, args in enumerate(mp_args):
            self.data.append(generate_one_batch(args))
            if logger is not None:
                logger(self.nb_batches * self.batch_size, b * self.batch_size)

        # Average the per-batch means of x and x^2. Every batch holds
        # batch_size vignettes (presumably all the same size), so this gives
        # the moments over the whole set.
        acc = 0.0
        acc_sq = 0.0
        for images, _ in self.data:
            acc += images.sum() / images.numel()
            acc_sq += images.pow(2).sum() / images.numel()

        mean = acc / self.nb_batches
        std = sqrt(acc_sq / self.nb_batches - mean * mean)
        for batch in self.data:
            batch[0].sub_(mean).div_(std)
            if cuda:
                batch[0] = batch[0].cuda()
                batch[1] = batch[1].cuda()

    def get_batch(self, b):
        """Return batch ``b`` as the ``[input, target]`` list stored in ``self.data``."""
        return self.data[b]
92
93 ######################################################################
94
class CompressedVignetteSet:
    """Set of SVRT vignettes kept compressed in memory, decoded per batch.

    The constructor generates ``nb_samples // batch_size`` batches with
    ``svrt.generate_vignettes``. It stores each batch's pixel storage
    compressed with ``svrt.compress``, alongside its labels. It also records
    the mean and standard deviation of the pixel values over the whole set,
    so that ``get_batch`` can standardize a batch after decompressing it.
    """

    # Abort generation once the process's peak resident set size exceeds this
    # many bytes (16e6 KiB, roughly 16 GB). This is a crude guard against a
    # suspected memory leak while generating.
    MAX_PEAK_RSS_BYTES = 16e6 * 1024

    @staticmethod
    def _peak_rss_bytes():
        """Return this process's peak resident set size, in bytes.

        getrusage reports ru_maxrss in kilobytes on Linux but in bytes on
        macOS, so the value is normalized here.
        """
        import sys
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        return peak if sys.platform == 'darwin' else peak * 1024

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
        """Generate, measure and compress the whole set.

        Args:
            problem_number: SVRT problem number passed to
                svrt.generate_vignettes.
            nb_samples: total number of vignettes; must be a positive
                multiple of batch_size.
            batch_size: number of vignettes per batch; must be positive.
            cuda: if true, get_batch moves the returned tensors to the GPU.
            logger: optional callable invoked after batch ``b`` is generated
                as ``logger(nb_batches * batch_size, b * batch_size)``.

        Raises:
            ValueError: if nb_samples or batch_size is not positive, or if
                nb_samples is not a multiple of batch_size.
            RuntimeError: if the peak RSS exceeds MAX_PEAK_RSS_BYTES during
                generation.
        """
        if batch_size <= 0 or nb_samples <= 0:
            raise ValueError('nb_samples and batch_size must be positive')
        if nb_samples % batch_size > 0:
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.cuda = cuda
        self.problem_number = problem_number

        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        self.targets = []
        self.input_storages = []

        acc = 0.0
        acc_sq = 0.0
        rss_before = self._peak_rss_bytes()
        float_buffer = None
        for b in range(self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)

            # FIXME converting through one reused float buffer should not be
            # necessary, but a fresh float copy per batch showed unexplained
            # memory leaks.
            if float_buffer is None:
                float_buffer = vignettes.float()
                # The compressed storage loses the shape; keep (H, W) for
                # get_batch.
                self.vignette_size = (vignettes.size(1), vignettes.size(2))
            else:
                float_buffer.copy_(vignettes)
            acc += float_buffer.sum() / vignettes.numel()
            acc_sq += float_buffer.pow(2).sum() / vignettes.numel()

            self.targets.append(target)
            self.input_storages.append(svrt.compress(vignettes.storage()))
            if logger is not None:
                logger(self.nb_batches * self.batch_size, b * self.batch_size)

            # FIXME crude leak detector
            if self._peak_rss_bytes() > self.MAX_PEAK_RSS_BYTES:
                raise RuntimeError('Memory leak?!')

        # Growth of the peak RSS during generation, in bytes.
        mem = self._peak_rss_bytes() - rss_before
        print('Using {:.02f}Gb total {:.02f}b / samples'
              .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples))

        # Every batch holds batch_size vignettes, so averaging the per-batch
        # means gives the moments over the whole set.
        self.mean = acc / self.nb_batches
        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)

    def get_batch(self, b):
        """Decompress batch ``b`` and return ``(input, target)``.

        ``input`` is a float tensor of shape (batch_size, 1, H, W), where
        (H, W) is the vignette size recorded at generation time. It is
        standardized with the set's mean and std. ``target`` is the stored
        label tensor. Both are moved to the GPU when ``self.cuda`` is true.
        """
        height, width = self.vignette_size
        images = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
        images = images.view(self.batch_size, 1, height, width).sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if self.cuda:
            images = images.cuda()
            target = target.cuda()

        return images, target
155
156 ######################################################################