X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=vignette_set.py;h=b95a1db75883d548db850bf72a5ea38ae5e27e45;hb=34aeb8100a6c19dae72779f9e46a0acbb5a069c7;hp=19a6f33ead1c4d95604dccbc70f4b1cb3b9c587e;hpb=e2368847af8e2eb5d6dda88b3318b64ec8637667;p=pysvrt.git diff --git a/vignette_set.py b/vignette_set.py index 19a6f33..b95a1db 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -18,11 +18,11 @@ # General Public License for more details. # # You should have received a copy of the GNU General Public License -# along with selector. If not, see . +# along with pysvrt. If not, see . import torch from math import sqrt -from torch.multiprocessing import Pool, cpu_count +from torch import multiprocessing from torch import Tensor from torch.autograd import Variable @@ -41,11 +41,16 @@ def generate_one_batch(s): class VignetteSet: - def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size seeds = torch.LongTensor(self.nb_batches).random_() @@ -53,11 +58,14 @@ class VignetteSet: for b in range(0, self.nb_batches): mp_args.append( [ problem_number, batch_size, seeds[b] ]) - # self.data = [] - # for b in range(0, self.nb_batches): - # self.data.append(generate_one_batch(mp_args[b])) + self.data = [] + for b in range(0, self.nb_batches): + self.data.append(generate_one_batch(mp_args[b])) + + # Weird thing going on with the multi-processing, waiting for more info - self.data = Pool(cpu_count()).map(generate_one_batch, mp_args) + # pool = multiprocessing.Pool(multiprocessing.cpu_count()) + # self.data = pool.map(generate_one_batch, mp_args) acc = 0.0 acc_sq = 0.0 @@ -80,11 +88,16 @@ class VignetteSet: ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size self.targets = [] self.input_storages = []