X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=svrtset.py;h=8bfdde199bd1c858d30ee42fe3e0ebbe30eeb860;hb=faac6461d3482204898ad97f811f58889cc37e1d;hp=cbc71a373237356007bca386de4890036ac902e5;hpb=b6cc25f4622917b0025d613d617b5ce9e23f07e9;p=pysvrt.git diff --git a/svrtset.py b/svrtset.py index cbc71a3..8bfdde1 100755 --- a/svrtset.py +++ b/svrtset.py @@ -41,17 +41,18 @@ def generate_one_batch(s): class VignetteSet: - def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None): if nb_samples%batch_size > 0: print('nb_samples must be a multiple of batch_size') raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] @@ -61,6 +62,7 @@ class VignetteSet: self.data = [] for b in range(0, self.nb_batches): self.data.append(generate_one_batch(mp_args[b])) + if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size) # Weird thing going on with the multi-processing, waiting for more info @@ -88,17 +90,19 @@ class VignetteSet: ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None): if nb_samples%batch_size > 0: print('nb_samples must be a multiple of batch_size') raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size + self.targets = [] self.input_storages = [] @@ -111,6 +115,7 @@ class CompressedVignetteSet: acc_sq += input.float().pow(2).sum() / input.numel() self.targets.append(target) self.input_storages.append(svrt.compress(input.storage())) + if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size) self.mean = acc / self.nb_batches self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)