Bug fix. Ouch, that was bad.
authorFrancois Fleuret <francois@fleuret.org>
Mon, 19 Jun 2017 16:26:17 +0000 (18:26 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 19 Jun 2017 16:26:17 +0000 (18:26 +0200)
svrtset.py

index 49f3279..8bfdde1 100755 (executable)
@@ -48,10 +48,11 @@ class VignetteSet:
             raise
 
         self.cuda = cuda
-        self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_samples // batch_size
-        self.nb_samples = self.nb_batches * self.batch_size
+
+        self.batch_size = batch_size
+        self.nb_samples = nb_samples
+        self.nb_batches = self.nb_samples // self.batch_size
 
         seeds = torch.LongTensor(self.nb_batches).random_()
         mp_args = []
@@ -96,10 +97,12 @@ class CompressedVignetteSet:
             raise
 
         self.cuda = cuda
-        self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_samples // batch_size
-        self.nb_samples = self.nb_batches * self.batch_size
+
+        self.batch_size = batch_size
+        self.nb_samples = nb_samples
+        self.nb_batches = self.nb_samples // self.batch_size
+
         self.targets = []
         self.input_storages = []