X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=vignette_set.py;h=5062f3e89d698029d7455fa6937e8ea01eb3129b;hp=ea5215944bddbf7dcf0d5b50efcd20293f76da6a;hb=15f2d2cf0a655234cfa435789e26238b95f5a371;hpb=c71899cfec905c50302be54725a97d7fbff08f54 diff --git a/vignette_set.py b/vignette_set.py index ea52159..5062f3e 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -22,6 +22,7 @@ import torch from math import sqrt +from torch import multiprocessing from torch import Tensor from torch.autograd import Variable @@ -30,45 +31,73 @@ import svrt ###################################################################### +def generate_one_batch(s): + problem_number, batch_size, random_seed = s + svrt.seed(random_seed) + target = torch.LongTensor(batch_size).bernoulli_(0.5) + input = svrt.generate_vignettes(problem_number, target) + input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) + return [ input, target ] + class VignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size - self.targets = [] - self.inputs = [] + + seeds = torch.LongTensor(self.nb_batches).random_() + mp_args = [] + for b in range(0, self.nb_batches): + mp_args.append( [ problem_number, batch_size, seeds[b] ]) + + self.data = [] + for b in range(0, self.nb_batches): + self.data.append(generate_one_batch(mp_args[b])) + + # Weird thing going on with the multi-processing, waiting for more info + + # pool = multiprocessing.Pool(multiprocessing.cpu_count()) + # self.data = pool.map(generate_one_batch, mp_args) acc = 0.0 acc_sq = 0.0 - for b in range(0, self.nb_batches): - target = torch.LongTensor(self.batch_size).bernoulli_(0.5) - input = svrt.generate_vignettes(problem_number, target) - input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if torch.cuda.is_available(): - input = input.cuda() - target = target.cuda() + input = self.data[b][0] acc += input.sum() / input.numel() acc_sq += input.pow(2).sum() / input.numel() - self.targets.append(target) - self.inputs.append(input) mean = acc / self.nb_batches std = sqrt(acc_sq / self.nb_batches - mean * mean) for b in range(0, self.nb_batches): - self.inputs[b].sub_(mean).div_(std) + self.data[b][0].sub_(mean).div_(std) + if cuda: + self.data[b][0] = self.data[b][0].cuda() + self.data[b][1] = self.data[b][1].cuda() def get_batch(self, b): - return self.inputs[b], self.targets[b] + return self.data[b] ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size self.targets = [] self.input_storages = [] @@ -84,14 +113,14 @@ class CompressedVignetteSet: self.input_storages.append(svrt.compress(input.storage())) self.mean = acc / self.nb_batches - self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean) + self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean) def get_batch(self, b): input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float() input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std) target = self.targets[b] - if torch.cuda.is_available(): + if self.cuda: input = input.cuda() target = target.cuda()