X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=vignette_set.py;h=695fed3442b862faee2984adb406a946163aad5e;hb=08ef6b7c332153cd72b7a225e27ee7af8882f313;hp=ea5215944bddbf7dcf0d5b50efcd20293f76da6a;hpb=c71899cfec905c50302be54725a97d7fbff08f54;p=pysvrt.git diff --git a/vignette_set.py b/vignette_set.py index ea52159..695fed3 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -31,7 +31,8 @@ import svrt ###################################################################### class VignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number self.nb_batches = nb_batches @@ -46,7 +47,7 @@ class VignetteSet: target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if torch.cuda.is_available(): + if self.cuda: input = input.cuda() target = target.cuda() acc += input.sum() / input.numel() @@ -65,7 +66,8 @@ class VignetteSet: ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number self.nb_batches = nb_batches @@ -84,14 +86,14 @@ class CompressedVignetteSet: self.input_storages.append(svrt.compress(input.storage())) self.mean = acc / self.nb_batches - self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean) + self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean) def get_batch(self, b): input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float() input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std) target = self.targets[b] - if torch.cuda.is_available(): + if self.cuda: input = input.cuda() target = target.cuda()