From: Francois Fleuret Date: Wed, 21 Jun 2017 06:33:05 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a4820693a16a173fd21b8a7580342fe0ad9e8af8;p=pysvrt.git Update. --- diff --git a/svrtset.py b/svrtset.py index 8bfdde1..0a14121 100755 --- a/svrtset.py +++ b/svrtset.py @@ -29,6 +29,9 @@ from torch.autograd import Variable import svrt +# FIXME +import resource + ###################################################################### def generate_one_batch(s): @@ -111,12 +114,17 @@ class CompressedVignetteSet: for b in range(0, self.nb_batches): target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) - acc += input.float().sum() / input.numel() - acc_sq += input.float().pow(2).sum() / input.numel() + acc += input.sum() / input.numel() + acc_sq += (input * input).sum() / input.numel() self.targets.append(target) self.input_storages.append(svrt.compress(input.storage())) if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size) + # FIXME + if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6: + print('Memory leak?!') + raise + self.mean = acc / self.nb_batches self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)