for b in range(0, self.nb_batches):
target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
- acc += float(input.sum()) / input.numel()
- acc_sq += float((input * input).sum()) / input.numel()
+
+ # FIXME input_as_float should not be necessary but there
+ # are weird memory leaks going on, which do not seem to be
+ # my fault
+ if b == 0:
+ input_as_float = input.float()
+ else:
+ input_as_float.copy_(input)
+ acc += input_as_float.sum() / input.numel()
+ acc_sq += input_as_float.pow(2).sum() / input.numel()
+
self.targets.append(target)
self.input_storages.append(svrt.compress(input.storage()))
if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)