import svrt
+# FIXME
+import resource
+
######################################################################
def generate_one_batch(s):
raise
self.cuda = cuda
- self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
raise
self.cuda = cuda
- self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
+
self.targets = []
self.input_storages = []
acc = 0.0
acc_sq = 0.0
+ usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
for b in range(0, self.nb_batches):
target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
- acc += input.float().sum() / input.numel()
- acc_sq += input.float().pow(2).sum() / input.numel()
+
+ # FIXME input_as_float should not be necessary but there
+ # are weird memory leaks going on, which do not seem to be
+ # my fault
+ if b == 0:
+ input_as_float = input.float()
+ else:
+ input_as_float.copy_(input)
+ acc += input_as_float.sum() / input.numel()
+ acc_sq += input_as_float.pow(2).sum() / input.numel()
+
self.targets.append(target)
self.input_storages.append(svrt.compress(input.storage()))
if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
+ # FIXME
+ if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6:
+ print('Memory leak?!')
+ raise
+
+ mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - usage) * 1024
+ print('Using {:.02f}Gb total {:.02f}b / samples'
+ .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples))
+
self.mean = acc / self.nb_batches
self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)