From 5f1df36dae9888175a44f2c0d9651f67a457904c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 10:59:36 +0200 Subject: [PATCH] Update. --- tasks.py | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 443419e..4ec410f 100755 --- a/tasks.py +++ b/tasks.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -634,8 +634,23 @@ class PicoCLVR(Task): class MNIST(Task): + def fourier_encode(self, x): + y = torch.linalg.lstsq( + self.fourier_basis.to(x.device).t(), x.float().t() + ).solution.t() + y = ((y / 255).clamp(min=0, max=1) * 255).long() + return y + + def fourier_decode(self, y): + return y.float() @ self.fourier_basis.to(y.device) + def __init__( - self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + self, + nb_train_samples, + nb_test_samples, + batch_size, + fourier=True, + device=torch.device("cpu"), ): super().__init__() @@ -648,6 +663,30 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + self.fourier = fourier + + if fourier: + self.create_fourier_basis() + + # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}") + + self.train_input = self.fourier_encode(self.train_input[:256]) + self.test_input = self.fourier_encode(self.test_input[:256]) + + torchvision.utils.save_image( + 1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256, + "train.png", + nrow=16, + ) + torchvision.utils.save_image( + 1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256, + "test.png", + nrow=16, + ) + # exit(0) + + print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}") + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -676,6 +715,9 @@ class MNIST(Task): deterministic_synthesis, device=self.device, ) + + results = self.fourier_decode(results) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, -- 2.39.5