self.global_basis_inverse = self.global_basis.inverse()
torchvision.utils.save_image(
- 1 - self.global_basis / self.global_basis.std(),
+ 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(),
"fourier.png",
nrow=28,
)
return y
def global_decode(self, y):
- y = (
- (y / 255.0) * (2 * self.range) - self.range
- ) * self.global_std + self.global_mu
+ y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to(
+ y.device
+ ) + self.global_mu.to(y.device)
return y.float() @ self.global_basis_inverse.to(y.device).t()
def __init__(