Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 21:55:01 +0000 (23:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 21:55:01 +0000 (23:55 +0200)
tasks.py

index 450a495..9100632 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -671,7 +671,7 @@ class MNIST(Task):
         self.global_basis_inverse = self.global_basis.inverse()
 
         torchvision.utils.save_image(
-            1 - self.global_basis / self.global_basis.std(),
+            1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(),
             "fourier.png",
             nrow=28,
         )
@@ -696,9 +696,9 @@ class MNIST(Task):
         return y
 
     def global_decode(self, y):
-        y = (
-            (y / 255.0) * (2 * self.range) - self.range
-        ) * self.global_std + self.global_mu
+        y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to(
+            y.device
+        ) + self.global_mu.to(y.device)
         return y.float() @ self.global_basis_inverse.to(y.device).t()
 
     def __init__(