From 605697b42bdf62c0d8a6715d43ab40b7446e9af2 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 16 Jun 2017 09:43:52 +0200 Subject: [PATCH] Made VignetteSet.__init__ multi-proc. --- vignette_set.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/vignette_set.py b/vignette_set.py index 695fed3..72880ba 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -22,6 +22,7 @@ import torch from math import sqrt +from multiprocessing import Pool, cpu_count from torch import Tensor from torch.autograd import Variable @@ -30,38 +31,47 @@ import svrt ###################################################################### +def generate_one_batch(s): + svrt.seed(s) + target = torch.LongTensor(self.batch_size).bernoulli_(0.5) + input = svrt.generate_vignettes(problem_number, target) + input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) + if self.cuda: + input = input.cuda() + target = target.cuda() + return [ input, target ] + class VignetteSet: + def __init__(self, problem_number, nb_batches, batch_size, cuda = False): self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number self.nb_batches = nb_batches self.nb_samples = self.nb_batches * self.batch_size - self.targets = [] - self.inputs = [] + + seed_list = torch.LongTensor(self.nb_batches).random_().tolist() + + # self.data = [] + # for b in range(0, self.nb_batches): + # self.data.append(generate_one_batch(seed_list[b])) + + self.data = Pool(cpu_count()).map(generate_one_batch, seed_list) acc = 0.0 acc_sq = 0.0 - for b in range(0, self.nb_batches): - target = torch.LongTensor(self.batch_size).bernoulli_(0.5) - input = svrt.generate_vignettes(problem_number, target) - input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if self.cuda: - input = input.cuda() - target = target.cuda() + input = self.data[b][0] acc += input.sum() / input.numel() acc_sq += input.pow(2).sum() / input.numel() - self.targets.append(target) - self.inputs.append(input) mean = acc / self.nb_batches std = sqrt(acc_sq / self.nb_batches - mean * mean) for b in range(0, self.nb_batches): - self.inputs[b].sub_(mean).div_(std) + self.data[b][0].sub_(mean).div_(std) def get_batch(self, b): - return self.inputs[b], self.targets[b] + return self.data[b] ###################################################################### -- 2.39.5