X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=svrtset.py;h=ecbaa68b992b81c20ba6a2874a4a74152e78566e;hp=0a14121c9a838fb15db18e5b3f114dad2f695e60;hb=a500b3545f6eb25c480e945f2e12786933c92423;hpb=a4820693a16a173fd21b8a7580342fe0ad9e8af8 diff --git a/svrtset.py b/svrtset.py index 0a14121..ecbaa68 100755 --- a/svrtset.py +++ b/svrtset.py @@ -114,8 +114,8 @@ class CompressedVignetteSet: for b in range(0, self.nb_batches): target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) - acc += input.sum() / input.numel() - acc_sq += (input * input).sum() / input.numel() + acc += float(input.sum()) / input.numel() + acc_sq += float((input * input).sum()) / input.numel() self.targets.append(target) self.input_storages.append(svrt.compress(input.storage())) if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)