######################################################################
+def generate_2d_fourier_basis(T):
+    # Create 1D vectors for time/space in both dimensions
+    t = torch.linspace(0, T - 1, T)
+
+    # Initialize an empty list to hold the basis vectors
+    basis = [torch.ones(T, T)]  # The constant (DC) component
+
+    # Generate cosine and sine terms for both dimensions
+    for nx in range(1, T // 2 + 1):
+        for ny in range(1, T // 2 + 1):
+            # Cosine and sine components in x- and y-directions
+            cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1)
+            sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1)
+            cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0)
+            sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0)
+
+            # Basis functions in 2D as outer products
+            basis.append(torch.mm(cos_x, cos_y))  # cos(nx) cos(ny)
+            basis.append(torch.mm(sin_x, sin_y))  # sin(nx) sin(ny)
+            basis.append(torch.mm(cos_x, sin_y))  # cos(nx) sin(ny)
+            basis.append(torch.mm(sin_x, cos_y))  # sin(nx) cos(ny)
+
+    # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T)
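+    # For T = 28 (as used below) the loops produce 1 + 4 * 14 * 14 = 785
+    # candidate functions; the slice keeps only the first T * T = 784 of them.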
+    basis_matrix = torch.stack(basis[: T * T], dim=0)
+
+    return basis_matrix
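+
+# Minimal usage sketch (hypothetical, not called anywhere in this file), where
+# x_flat is assumed to be an (N, 784) batch of flattened 28x28 images:
+#
+#   basis = generate_2d_fourier_basis(T=28)   # (784, 28, 28)
+#   flat = basis.flatten(1)                   # (784, 784)
+#   x_hat = (x_flat.float() @ flat.t()) @ flat.inverse().t()
+#
+# x_hat should reproduce x_flat up to numerical error; this is the round trip
+# that global_encode / global_decode below perform, plus quantization.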
+
+
class MNIST(Task):
-    def fourier_encode(self, x):
-        y = torch.linalg.lstsq(
-            self.fourier_basis.to(x.device).t(), x.float().t()
-        ).solution.t()
-        y = ((y / 255).clamp(min=0, max=1) * 255).long()
+    def create_global_basis(self):
+        # self.global_basis = torch.randn(784, 784)
+
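+        # Rows of global_basis are the flattened 28x28 Fourier basis images;
+        # global_basis_inverse maps coefficient vectors back to pixel space.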
+        self.global_basis = generate_2d_fourier_basis(T=28).flatten(1)
+        self.global_basis_inverse = self.global_basis.inverse()
+
+        torchvision.utils.save_image(
+            1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(),
+            "fourier.png",
+            nrow=28,
+        )
+
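+        # Per-coefficient statistics over the training set, used to clip the
+        # coefficients to +/- self.range standard deviations and quantize them
+        # into 256 levels in global_encode / global_decode.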
+        y = self.train_input.float() @ self.global_basis.t()
+        self.range = 4
+        self.global_mu = y.mean(dim=0, keepdim=True)
+        self.global_std = y.std(dim=0, keepdim=True)
+
+        # for k in range(25):
+        #     X = self.global_encode(self.train_input).float()
+        #     print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item())
+        #     sys.stdout.flush()
+        #     self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution
+
+    def global_encode(self, x):
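+        # Project the pixels onto the global basis, standardize each
+        # coefficient, clip it to [-range, range], and map it linearly to an
+        # integer token in {0, ..., 255}.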
+        y = x.float() @ self.global_basis.t()
+        y = ((y - self.global_mu) / self.global_std).clamp(
+            min=-self.range, max=self.range
+        )
+        y = (((y + self.range) / (2 * self.range)) * 255).long()
        return y
-    def fourier_decode(self, y):
-        return y.float() @ self.fourier_basis.to(y.device)
+    def global_decode(self, y):
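+        # Inverse of global_encode: map integer tokens back to standardized
+        # coefficients, undo the standardization, then project back to pixels.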
+        y = (y / 255.0) * (2 * self.range) - self.range
+        y = y * self.global_std.to(y.device) + self.global_mu.to(y.device)
+        return y.float() @ self.global_basis_inverse.to(y.device).t()
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
-        fourier=True,
+        global_representation=True,
        device=torch.device("cpu"),
    ):
        super().__init__()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
-        self.fourier = fourier
-
-        if fourier:
-            self.create_fourier_basis()
+        self.global_representation = global_representation
-            # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}")
+        if global_representation:
+            self.create_global_basis()
-            self.train_input = self.fourier_encode(self.train_input[:256])
-            self.test_input = self.fourier_encode(self.test_input[:256])
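+            # Re-encode the images as quantized global-basis coefficients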
+            self.train_input = self.global_encode(self.train_input)
+            self.test_input = self.global_encode(self.test_input)
            torchvision.utils.save_image(
-                1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256,
-                "train.png",
+                1
+                - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                "check-train.png",
                nrow=16,
            )
            torchvision.utils.save_image(
-                1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256,
-                "test.png",
+                1
+                - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                "check-test.png",
                nrow=16,
            )
            # exit(0)
            device=self.device,
        )
-        results = self.fourier_decode(results)
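+        # Map the generated token sequences back to approximate pixel values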
+        if self.global_representation:
+            results = self.global_decode(results)
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(