From: François Fleuret Date: Sun, 29 Sep 2024 21:55:01 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a5a653bbebcd89b616918f44543c85140ecdaa33;p=picoclvr.git Update. --- diff --git a/tasks.py b/tasks.py index 450a495..9100632 100755 --- a/tasks.py +++ b/tasks.py @@ -671,7 +671,7 @@ class MNIST(Task): self.global_basis_inverse = self.global_basis.inverse() torchvision.utils.save_image( - 1 - self.global_basis / self.global_basis.std(), + 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(), "fourier.png", nrow=28, ) @@ -696,9 +696,9 @@ class MNIST(Task): return y def global_decode(self, y): - y = ( - (y / 255.0) * (2 * self.range) - self.range - ) * self.global_std + self.global_mu + y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to( + y.device + ) + self.global_mu.to(y.device) return y.float() @ self.global_basis_inverse.to(y.device).t() def __init__(