X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=vignette_set.py;h=0b6de7e17d7db8fc4897f844c889acf549cda562;hb=cefdf80cffc5f897dc728d68bf927f522e3e1608;hp=c46beea3b2a809fcd7ff49db2be0e6d0d6bd992e;hpb=abbbb61852f54e90df6ac5b5f4dcb71d06f88f49;p=pysvrt.git diff --git a/vignette_set.py b/vignette_set.py index c46beea..0b6de7e 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -22,7 +22,7 @@ import torch from math import sqrt -from multiprocessing import Pool, cpu_count +from torch import multiprocessing from torch import Tensor from torch.autograd import Variable @@ -32,14 +32,11 @@ import svrt ###################################################################### def generate_one_batch(s): - problem_number, batch_size, cuda, random_seed = s + problem_number, batch_size, random_seed = s svrt.seed(random_seed) target = torch.LongTensor(batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if cuda: - input = input.cuda() - target = target.cuda() return [ input, target ] class VignetteSet: @@ -54,13 +51,16 @@ class VignetteSet: seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] for b in range(0, self.nb_batches): - mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ]) + mp_args.append( [ problem_number, batch_size, seeds[b] ]) - # self.data = [] - # for b in range(0, self.nb_batches): - # self.data.append(generate_one_batch(mp_args[b])) + self.data = [] + for b in range(0, self.nb_batches): + self.data.append(generate_one_batch(mp_args[b])) + + # Weird thing going on with the multi-processing, waiting for more info - self.data = Pool(cpu_count()).map(generate_one_batch, mp_args) + # pool = multiprocessing.Pool(multiprocessing.cpu_count()) + # self.data = pool.map(generate_one_batch, mp_args) acc = 0.0 acc_sq = 0.0 @@ -73,6 +73,9 @@ class VignetteSet: std = sqrt(acc_sq / self.nb_batches - mean * mean) for b in range(0, self.nb_batches): self.data[b][0].sub_(mean).div_(std) + if cuda: + self.data[b][0] = self.data[b][0].cuda() + self.data[b][1] = self.data[b][1].cuda() def get_batch(self, b): return self.data[b]