self.global_basis_inverse = self.global_basis.inverse()
 
         torchvision.utils.save_image(
-            1 - self.global_basis / self.global_basis.std(),
+            1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(),
             "fourier.png",
             nrow=28,
         )
         return y
 
     def global_decode(self, y):
-        y = (
-            (y / 255.0) * (2 * self.range) - self.range
-        ) * self.global_std + self.global_mu
+        y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to(
+            y.device
+        ) + self.global_mu.to(y.device)
         return y.float() @ self.global_basis_inverse.to(y.device).t()
 
     def __init__(