Update.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2017 06:33:05 +0000 (08:33 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2017 06:33:05 +0000 (08:33 +0200)
svrtset.py

index 8bfdde1..0a14121 100755 (executable)
@@ -29,6 +29,9 @@ from torch.autograd import Variable
 
 import svrt
 
+# FIXME
+import resource
+
 ######################################################################
 
 def generate_one_batch(s):
@@ -111,12 +114,17 @@ class CompressedVignetteSet:
         for b in range(0, self.nb_batches):
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
-            acc += input.float().sum() / input.numel()
-            acc_sq += input.float().pow(2).sum() /  input.numel()
+            acc += input.sum() / input.numel()
+            acc_sq += (input * input).sum() /  input.numel()
             self.targets.append(target)
             self.input_storages.append(svrt.compress(input.storage()))
             if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
 
+            # FIXME
+            if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6:
+                print('Memory leak?!')
+                raise
+
         self.mean = acc / self.nb_batches
         self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)