class MNIST(Task):
def create_fourier_basis(self):
    """Build the Fourier basis and the statistics used for quantization.

    Sets the following attributes:

    * ``fourier_basis`` -- output of ``generate_2d_fourier_basis(T=28)``
      flattened from dimension 1 onward, so each row is one basis element
      (presumably 784 x 784 -- the matrix must be square to be inverted).
    * ``fourier_basis_inverse`` -- matrix inverse of ``fourier_basis``.
    * ``fourier_range`` -- clamp bound, in standard deviations, applied by
      ``fourier_encode``.
    * ``fourier_mu`` / ``fourier_std`` -- per-coefficient mean and standard
      deviation of ``self.train_input`` projected on the basis (shape
      ``(1, n_coefficients)`` because of ``keepdim=True``).

    Side effect: writes an image grid of the basis to ``fourier.png``.

    Requires ``self.train_input`` to be set before the call.
    """
    self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
    self.fourier_basis_inverse = self.fourier_basis.inverse()

    # Debug visualization: one 28x28 tile per basis row, scaled by the
    # global standard deviation of the basis.
    torchvision.utils.save_image(
        1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(),
        "fourier.png",
        nrow=28,
    )

    y = self.train_input.float() @ self.fourier_basis.t()
    self.fourier_range = 4
    self.fourier_mu = y.mean(dim=0, keepdim=True)

    std = y.std(dim=0, keepdim=True)
    # A coefficient that is constant over the training set has std == 0, and
    # a single training sample yields NaN; dividing by either in
    # fourier_encode would produce NaN/inf codes. Fall back to a unit scale
    # for those entries only; every well-defined std is kept as is.
    self.fourier_std = torch.where(std > 0, std, torch.ones_like(std))
+
def fourier_encode(self, x):
    """Project ``x`` on the Fourier basis and quantize it to integer codes.

    The input is cast to float, multiplied by the transposed basis,
    standardized with ``fourier_mu`` / ``fourier_std``, clamped to
    ``[-fourier_range, fourier_range]`` and mapped linearly to ``[0, 255]``.
    The final ``.long()`` truncates the fractional part (it does not round).

    The basis and statistics are moved to ``x.device`` first, mirroring
    what ``fourier_decode`` does, so inputs living on another device work.

    Args:
        x: tensor whose last dimension matches the basis row length.

    Returns:
        A ``torch.long`` tensor of codes in ``[0, 255]``.
    """
    dev = x.device
    y = x.float() @ self.fourier_basis.to(dev).t()
    y = ((y - self.fourier_mu.to(dev)) / self.fourier_std.to(dev)).clamp(
        min=-self.fourier_range, max=self.fourier_range
    )
    y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long()
    return y
def fourier_decode(self, y):
    """Map quantized codes back through the inverse Fourier basis.

    Reverses the steps of ``fourier_encode``: codes in ``[0, 255]`` are
    rescaled to ``[-fourier_range, fourier_range]``, de-standardized with
    ``fourier_std`` / ``fourier_mu`` and multiplied by the transposed
    inverse basis. All stored tensors are moved to ``y.device`` first.

    Args:
        y: tensor of codes as produced by ``fourier_encode``.

    Returns:
        A float tensor with the reconstructed values.
    """
    target_device = y.device
    bound = self.fourier_range

    # [0, 255] -> [-bound, bound]
    standardized = (y / 255.0) * (2 * bound) - bound

    # Undo the per-coefficient standardization.
    scale = self.fourier_std.to(target_device)
    shift = self.fourier_mu.to(target_device)
    coefficients = standardized * scale + shift

    inverse_basis = self.fourier_basis_inverse.to(target_device)
    return coefficients.float() @ inverse_basis.t()
def __init__(
self,
nb_train_samples,
nb_test_samples,
batch_size,
- global_representation=True,
+ fourier_representation=True,
device=torch.device("cpu"),
):
super().__init__()
data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
- self.global_representation = global_representation
+ self.fourier_representation = fourier_representation
- if global_representation:
- self.create_global_basis()
+ if fourier_representation:
+ self.create_fourier_basis()
- self.train_input = self.global_encode(self.train_input)
- self.test_input = self.global_encode(self.test_input)
+ self.train_input = self.fourier_encode(self.train_input)
+ self.test_input = self.fourier_encode(self.test_input)
torchvision.utils.save_image(
1
- - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+ - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
/ 256,
"check-train.png",
nrow=16,
)
torchvision.utils.save_image(
1
- - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
+ - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
/ 256,
"check-test.png",
nrow=16,
device=self.device,
)
- if self.global_representation:
- results = self.global_decode(results)
+ if self.fourier_representation:
+ results = self.fourier_decode(results)
image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
torchvision.utils.save_image(