From 0057bf10615d045a043014cfe1bcb01ba9c3871d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 30 Sep 2024 08:16:20 +0200 Subject: [PATCH] Update. --- main.py | 6 +++++ tasks.py | 68 ++++++++++++++++++++++++-------------------------------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/main.py b/main.py index 3ff64b7..fe7b49e 100755 --- a/main.py +++ b/main.py @@ -88,6 +88,11 @@ parser.add_argument("--resume", action="store_true", default=False) parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") +############################## +# MNIST + +parser.add_argument("--mnist_fourier", action="store_true", default=False) + ############################## # filetask @@ -546,6 +551,7 @@ elif args.task == "mnist": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, + fourier_representation=args.mnist_fourier, device=device, ) diff --git a/tasks.py b/tasks.py index 9100632..5f3258b 100755 --- a/tasks.py +++ b/tasks.py @@ -662,51 +662,41 @@ def generate_2d_fourier_basis(T): class MNIST(Task): - def create_global_basis(self): - import numpy as np - - # self.global_basis = torch.randn(784, 784) - - self.global_basis = generate_2d_fourier_basis(T=28).flatten(1) - self.global_basis_inverse = self.global_basis.inverse() + def create_fourier_basis(self): + self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.fourier_basis_inverse = self.fourier_basis.inverse() torchvision.utils.save_image( - 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(), + 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(), "fourier.png", nrow=28, ) - y = self.train_input.float() @ self.global_basis.t() - self.range = 4 - self.global_mu = y.mean(dim=0, keepdim=True) - self.global_std = y.std(dim=0, keepdim=True) - - # for k in range(25): - # X = self.global_encode(self.train_input).float() - # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item()) - # sys.stdout.flush() - # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution - - def global_encode(self, x): - y = x.float() @ self.global_basis.t() - y = ((y - self.global_mu) / self.global_std).clamp( - min=-self.range, max=self.range + y = self.train_input.float() @ self.fourier_basis.t() + self.fourier_range = 4 + self.fourier_mu = y.mean(dim=0, keepdim=True) + self.fourier_std = y.std(dim=0, keepdim=True) + + def fourier_encode(self, x): + y = x.float() @ self.fourier_basis.t() + y = ((y - self.fourier_mu) / self.fourier_std).clamp( + min=-self.fourier_range, max=self.fourier_range ) - y = (((y + self.range) / (2 * self.range)) * 255).long() + y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long() return y - def global_decode(self, y): - y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to( - y.device - ) + self.global_mu.to(y.device) - return y.float() @ self.global_basis_inverse.to(y.device).t() + def fourier_decode(self, y): + y = ( + (y / 255.0) * (2 * self.fourier_range) - self.fourier_range + ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device) + return y.float() @ self.fourier_basis_inverse.to(y.device).t() def __init__( self, nb_train_samples, nb_test_samples, batch_size, - global_representation=True, + fourier_representation=True, device=torch.device("cpu"), ): super().__init__() @@ -720,24 +710,24 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() - self.global_representation = global_representation + self.fourier_representation = fourier_representation - if global_representation: - self.create_global_basis() + if fourier_representation: + self.create_fourier_basis() - self.train_input = self.global_encode(self.train_input) - self.test_input = self.global_encode(self.test_input) + self.train_input = self.fourier_encode(self.train_input) + self.test_input = self.fourier_encode(self.test_input) torchvision.utils.save_image( 1 - - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) / 256, "check-train.png", nrow=16, ) torchvision.utils.save_image( 1 - - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) + - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) / 256, "check-test.png", nrow=16, @@ -775,8 +765,8 @@ class MNIST(Task): device=self.device, ) - if self.global_representation: - results = self.global_decode(results) + if self.fourier_representation: + results = self.fourier_decode(results) image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( -- 2.39.5