X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=svrtset.py;h=54c4dec2efd560810be6ea9e76a5b60c82b162c7;hp=49f32799da7b0ba4a880d9aacc90278617908220;hb=HEAD;hpb=e754d1075d8d0a5949e71f426ab07ce73be6099e diff --git a/svrtset.py b/svrtset.py index 49f3279..54c4dec 100755 --- a/svrtset.py +++ b/svrtset.py @@ -29,6 +29,9 @@ from torch.autograd import Variable import svrt +# FIXME +import resource + ###################################################################### def generate_one_batch(s): @@ -48,10 +51,11 @@ class VignetteSet: raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] @@ -96,24 +100,45 @@ class CompressedVignetteSet: raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size + self.targets = [] self.input_storages = [] acc = 0.0 acc_sq = 0.0 + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss for b in range(0, self.nb_batches): target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) - acc += input.float().sum() / input.numel() - acc_sq += input.float().pow(2).sum() / input.numel() + + # FIXME input_as_float should not be necessary but there + # are weird memory leaks going on, which do not seem to be + # my fault + if b == 0: + input_as_float = input.float() + else: + input_as_float.copy_(input) + acc += input_as_float.sum() / input.numel() + acc_sq += input_as_float.pow(2).sum() / input.numel() + self.targets.append(target) self.input_storages.append(svrt.compress(input.storage())) if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size) + # FIXME + if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6: + print('Memory leak?!') + raise + + mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - usage) * 1024 + print('Using {:.02f}Gb total {:.02f}b / samples' + .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples)) + self.mean = acc / self.nb_batches self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)