From 8c709c982d64948ab1c8949930dc0468f91039aa Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 21 Jun 2017 08:48:16 +0200 Subject: [PATCH] Try to fix weird memory leaks with ugly hacks. --- svrtset.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/svrtset.py b/svrtset.py index ecbaa68..4022752 100755 --- a/svrtset.py +++ b/svrtset.py @@ -114,8 +114,17 @@ class CompressedVignetteSet: for b in range(0, self.nb_batches): target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) - acc += float(input.sum()) / input.numel() - acc_sq += float((input * input).sum()) / input.numel() + + # FIXME input_as_float should not be necessary but there + # are weird memory leaks going on, which do not seem to be + # my fault + if b == 0: + input_as_float = input.float() + else: + input_as_float.copy_(input) + acc += input_as_float.sum() / input.numel() + acc_sq += input_as_float.pow(2).sum() / input.numel() + self.targets.append(target) self.input_storages.append(svrt.compress(input.storage())) if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size) -- 2.39.5