Update.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2017 06:35:16 +0000 (08:35 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2017 06:35:16 +0000 (08:35 +0200)
svrtset.py

index 0a14121..ecbaa68 100755 (executable)
@@ -114,8 +114,8 @@ class CompressedVignetteSet:
         for b in range(0, self.nb_batches):
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
-            acc += input.sum() / input.numel()
-            acc_sq += (input * input).sum() /  input.numel()
+            acc += float(input.sum()) / input.numel()
+            acc_sq += float((input * input).sum()) /  input.numel()
             self.targets.append(target)
             self.input_storages.append(svrt.compress(input.storage()))
             if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)