Minor update.
[pysvrt.git] / svrtset.py
index 49f3279..54c4dec 100755 (executable)
@@ -29,6 +29,9 @@ from torch.autograd import Variable
 
 import svrt
 
+# FIXME
+import resource
+
 ######################################################################
 
 def generate_one_batch(s):
@@ -48,10 +51,11 @@ class VignetteSet:
             raise
 
         self.cuda = cuda
-        self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_samples // batch_size
-        self.nb_samples = self.nb_batches * self.batch_size
+
+        self.batch_size = batch_size
+        self.nb_samples = nb_samples
+        self.nb_batches = self.nb_samples // self.batch_size
 
         seeds = torch.LongTensor(self.nb_batches).random_()
         mp_args = []
@@ -96,24 +100,45 @@ class CompressedVignetteSet:
             raise
 
         self.cuda = cuda
-        self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_samples // batch_size
-        self.nb_samples = self.nb_batches * self.batch_size
+
+        self.batch_size = batch_size
+        self.nb_samples = nb_samples
+        self.nb_batches = self.nb_samples // self.batch_size
+
         self.targets = []
         self.input_storages = []
 
         acc = 0.0
         acc_sq = 0.0
+        usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
         for b in range(0, self.nb_batches):
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
-            acc += input.float().sum() / input.numel()
-            acc_sq += input.float().pow(2).sum() /  input.numel()
+
+            # FIXME input_as_float should not be necessary but there
+            # are weird memory leaks going on, which do not seem to be
+            # my fault
+            if b == 0:
+                input_as_float = input.float()
+            else:
+                input_as_float.copy_(input)
+            acc += input_as_float.sum() / input.numel()
+            acc_sq += input_as_float.pow(2).sum() /  input.numel()
+
             self.targets.append(target)
             self.input_storages.append(svrt.compress(input.storage()))
             if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)
 
+            # FIXME
+            if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6:
+                print('Memory leak?!')
+                raise
+
+        mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - usage) * 1024
+        print('Using {:.02f}Gb total {:.02f}b / samples'
+              .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples))
+
         self.mean = acc / self.nb_batches
         self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)