From a5a653bbebcd89b616918f44543c85140ecdaa33 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 23:55:01 +0200 Subject: [PATCH] Update. --- tasks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tasks.py b/tasks.py index 450a495..9100632 100755 --- a/tasks.py +++ b/tasks.py @@ -671,7 +671,7 @@ class MNIST(Task): self.global_basis_inverse = self.global_basis.inverse() torchvision.utils.save_image( - 1 - self.global_basis / self.global_basis.std(), + 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(), "fourier.png", nrow=28, ) @@ -696,9 +696,9 @@ class MNIST(Task): return y def global_decode(self, y): - y = ( - (y / 255.0) * (2 * self.range) - self.range - ) * self.global_std + self.global_mu + y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to( + y.device + ) + self.global_mu.to(y.device) return y.float() @ self.global_basis_inverse.to(y.device).t() def __init__( -- 2.39.5