Cosmetics.
[pysvrt.git] / vignette_set.py
index 72880ba..5062f3e 100755 (executable)
@@ -22,7 +22,7 @@
 
 import torch
 from math import sqrt
-from multiprocessing import Pool, cpu_count
+from torch import multiprocessing
 
 from torch import Tensor
 from torch.autograd import Variable
@@ -32,31 +32,40 @@ import svrt
 ######################################################################
 
 def generate_one_batch(s):
-    svrt.seed(s)
-    target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
+    problem_number, batch_size, random_seed = s
+    svrt.seed(random_seed)
+    target = torch.LongTensor(batch_size).bernoulli_(0.5)
     input = svrt.generate_vignettes(problem_number, target)
     input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-    if self.cuda:
-        input = input.cuda()
-        target = target.cuda()
     return [ input, target ]
 
 class VignetteSet:
 
-    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+    def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+        if nb_samples%batch_size > 0:
+            print('nb_samples must be a mutiple of batch_size')
+            raise
+
         self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_batches
+        self.nb_batches = nb_samples // batch_size
         self.nb_samples = self.nb_batches * self.batch_size
 
-        seed_list = torch.LongTensor(self.nb_batches).random_().tolist()
+        seeds = torch.LongTensor(self.nb_batches).random_()
+        mp_args = []
+        for b in range(0, self.nb_batches):
+            mp_args.append( [ problem_number, batch_size, seeds[b] ])
 
-        self.data = []
-        for b in range(0, self.nb_batches):
-            # self.data.append(generate_one_batch(seed_list[b]))
+        self.data = []
+        for b in range(0, self.nb_batches):
+            self.data.append(generate_one_batch(mp_args[b]))
 
-        self.data = Pool(cpu_count()).map(generate_one_batch, seed_list)
+        # Weird thing going on with the multi-processing, waiting for more info
+
+        # pool = multiprocessing.Pool(multiprocessing.cpu_count())
+        # self.data = pool.map(generate_one_batch, mp_args)
 
         acc = 0.0
         acc_sq = 0.0
@@ -69,6 +78,9 @@ class VignetteSet:
         std = sqrt(acc_sq / self.nb_batches - mean * mean)
         for b in range(0, self.nb_batches):
             self.data[b][0].sub_(mean).div_(std)
+            if cuda:
+                self.data[b][0] = self.data[b][0].cuda()
+                self.data[b][1] = self.data[b][1].cuda()
 
     def get_batch(self, b):
         return self.data[b]
@@ -76,11 +88,16 @@ class VignetteSet:
 ######################################################################
 
 class CompressedVignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+    def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+        if nb_samples%batch_size > 0:
+            print('nb_samples must be a mutiple of batch_size')
+            raise
+
         self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_batches
+        self.nb_batches = nb_samples // batch_size
         self.nb_samples = self.nb_batches * self.batch_size
         self.targets = []
         self.input_storages = []