From 3882ebe7fdc72ac5cd3b39bc6d5d4c4924a20a9e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 30 Sep 2024 08:28:37 +0200 Subject: [PATCH] Update. --- tasks.py | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tasks.py b/tasks.py index 5f3258b..9901715 100755 --- a/tasks.py +++ b/tasks.py @@ -665,13 +665,6 @@ class MNIST(Task): def create_fourier_basis(self): self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) self.fourier_basis_inverse = self.fourier_basis.inverse() - - torchvision.utils.save_image( - 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(), - "fourier.png", - nrow=28, - ) - y = self.train_input.float() @ self.fourier_basis.t() self.fourier_range = 4 self.fourier_mu = y.mean(dim=0, keepdim=True) @@ -718,24 +711,6 @@ class MNIST(Task): self.train_input = self.fourier_encode(self.train_input) self.test_input = self.fourier_encode(self.test_input) - torchvision.utils.save_image( - 1 - - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) - / 256, - "check-train.png", - nrow=16, - ) - torchvision.utils.save_image( - 1 - - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) - / 256, - "check-test.png", - nrow=16, - ) - # exit(0) - - print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}") - def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -754,6 +729,26 @@ class MNIST(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): + if n_epoch == 0: + image_name = os.path.join(result_dir, "fourier.png") + torchvision.utils.save_image( + 0.5 + - 0.5 + * self.fourier_basis.reshape(-1, 1, 28, 28) + / self.fourier_basis.std(), + image_name, + nrow=28, + ) + + image_name = os.path.join(result_dir, "check-train.png") + torchvision.utils.save_image( + 1 + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + image_name, + nrow=16, + ) + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) ar_mask = torch.full_like(results, 1) masked_inplace_autoregression( -- 2.39.5