From 80c23832f3c8a29ef684ec68e580c8af68c43414 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 23:06:25 +0200 Subject: [PATCH] Update. --- tasks.py | 100 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/tasks.py b/tasks.py index 4ec410f..450a495 100755 --- a/tasks.py +++ b/tasks.py @@ -633,23 +633,80 @@ class PicoCLVR(Task): ###################################################################### +def generate_2d_fourier_basis(T): + # Create 1D vectors for time/space in both dimensions + t = torch.linspace(0, T - 1, T) + + # Initialize an empty list to hold the basis vectors + basis = [torch.ones(T, T)] # The constant (DC) component + + # Generate cosine and sine terms for both dimensions + for nx in range(1, T // 2 + 1): + for ny in range(1, T // 2 + 1): + # Cosine and sine components in x- and y-directions + cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1) + sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1) + cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0) + sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0) + + # Basis functions in 2D as outer products + basis.append(torch.mm(cos_x, cos_y)) # cos(nx)cos(ny) + basis.append(torch.mm(sin_x, sin_y)) # sin(nx)sin(ny) + basis.append(torch.mm(cos_x, sin_y)) # cos(nx)sin(ny) + basis.append(torch.mm(sin_x, cos_y)) # sin(nx)cos(ny) + + # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T) + basis_matrix = torch.stack(basis[: T * T], dim=0) + + return basis_matrix + + class MNIST(Task): - def fourier_encode(self, x): - y = torch.linalg.lstsq( - self.fourier_basis.to(x.device).t(), x.float().t() - ).solution.t() - y = ((y / 255).clamp(min=0, max=1) * 255).long() + def create_global_basis(self): + import numpy as np + + # self.global_basis = torch.randn(784, 784) + + self.global_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.global_basis_inverse = self.global_basis.inverse() + + torchvision.utils.save_image( + 1 - self.global_basis / self.global_basis.std(), + "fourier.png", + nrow=28, + ) + + y = self.train_input.float() @ self.global_basis.t() + self.range = 4 + self.global_mu = y.mean(dim=0, keepdim=True) + self.global_std = y.std(dim=0, keepdim=True) + + # for k in range(25): + # X = self.global_encode(self.train_input).float() + # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item()) + # sys.stdout.flush() + # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution + + def global_encode(self, x): + y = x.float() @ self.global_basis.t() + y = ((y - self.global_mu) / self.global_std).clamp( + min=-self.range, max=self.range + ) + y = (((y + self.range) / (2 * self.range)) * 255).long() return y - def fourier_decode(self, y): - return y.float() @ self.fourier_basis.to(y.device) + def global_decode(self, y): + y = ( + (y / 255.0) * (2 * self.range) - self.range + ) * self.global_std + self.global_mu + return y.float() @ self.global_basis_inverse.to(y.device).t() def __init__( self, nb_train_samples, nb_test_samples, batch_size, - fourier=True, + global_representation=True, device=torch.device("cpu"), ): super().__init__() @@ -663,24 +720,26 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() - self.fourier = fourier - - if fourier: - self.create_fourier_basis() + self.global_representation = global_representation - # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}") + if global_representation: + self.create_global_basis() - self.train_input = self.fourier_encode(self.train_input[:256]) - self.test_input = self.fourier_encode(self.test_input[:256]) + self.train_input = self.global_encode(self.train_input) + self.test_input = self.global_encode(self.test_input) torchvision.utils.save_image( - 1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256, - "train.png", + 1 + - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + "check-train.png", nrow=16, ) torchvision.utils.save_image( - 1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256, - "test.png", + 1 + - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) + / 256, + "check-test.png", nrow=16, ) # exit(0) @@ -716,7 +775,8 @@ class MNIST(Task): device=self.device, ) - results = self.fourier_decode(results) + if self.global_representation: + results = self.global_decode(results) image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( -- 2.39.5