From: Francois Fleuret Date: Fri, 16 Jun 2017 08:39:11 +0000 (+0200) Subject: Trying to make multiprocessing and cuda friends with each other. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e2368847af8e2eb5d6dda88b3318b64ec8637667;p=pysvrt.git Trying to make multiprocessing and cuda friends with each other. --- diff --git a/vignette_set.py b/vignette_set.py index c46beea..19a6f33 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -22,7 +22,7 @@ import torch from math import sqrt -from multiprocessing import Pool, cpu_count +from torch.multiprocessing import Pool, cpu_count from torch import Tensor from torch.autograd import Variable @@ -32,14 +32,11 @@ import svrt ###################################################################### def generate_one_batch(s): - problem_number, batch_size, cuda, random_seed = s + problem_number, batch_size, random_seed = s svrt.seed(random_seed) target = torch.LongTensor(batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if cuda: - input = input.cuda() - target = target.cuda() return [ input, target ] class VignetteSet: @@ -54,7 +51,7 @@ class VignetteSet: seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] for b in range(0, self.nb_batches): - mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ]) + mp_args.append( [ problem_number, batch_size, seeds[b] ]) # self.data = [] # for b in range(0, self.nb_batches): @@ -73,6 +70,9 @@ class VignetteSet: std = sqrt(acc_sq / self.nb_batches - mean * mean) for b in range(0, self.nb_batches): self.data[b][0].sub_(mean).div_(std) + if cuda: + self.data[b][0] = self.data[b][0].cuda() + self.data[b][1] = self.data[b][1].cuda() def get_batch(self, b): return self.data[b]