X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=vignette_set.py;h=b95a1db75883d548db850bf72a5ea38ae5e27e45;hb=34aeb8100a6c19dae72779f9e46a0acbb5a069c7;hp=19a6f33ead1c4d95604dccbc70f4b1cb3b9c587e;hpb=e2368847af8e2eb5d6dda88b3318b64ec8637667;p=pysvrt.git
diff --git a/vignette_set.py b/vignette_set.py
index 19a6f33..b95a1db 100755
--- a/vignette_set.py
+++ b/vignette_set.py
@@ -18,11 +18,11 @@
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
-# along with selector. If not, see .
+# along with pysvrt. If not, see .
import torch
from math import sqrt
-from torch.multiprocessing import Pool, cpu_count
+from torch import multiprocessing
from torch import Tensor
from torch.autograd import Variable
@@ -41,11 +41,16 @@ def generate_one_batch(s):
class VignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
@@ -53,11 +58,14 @@ class VignetteSet:
for b in range(0, self.nb_batches):
mp_args.append( [ problem_number, batch_size, seeds[b] ])
- # self.data = []
- # for b in range(0, self.nb_batches):
- # self.data.append(generate_one_batch(mp_args[b]))
+ self.data = []
+ for b in range(0, self.nb_batches):
+ self.data.append(generate_one_batch(mp_args[b]))
+
+ # Weird thing going on with the multi-processing, waiting for more info
- self.data = Pool(cpu_count()).map(generate_one_batch, mp_args)
+ # pool = multiprocessing.Pool(multiprocessing.cpu_count())
+ # self.data = pool.map(generate_one_batch, mp_args)
acc = 0.0
acc_sq = 0.0
@@ -80,11 +88,16 @@ class VignetteSet:
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []