Update. master
authorFrançois Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000 (08:28 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000 (08:28 +0200)
tasks.py

index 5f3258b..9901715 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -665,13 +665,6 @@ class MNIST(Task):
     def create_fourier_basis(self):
         self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
         self.fourier_basis_inverse = self.fourier_basis.inverse()
-
-        torchvision.utils.save_image(
-            1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(),
-            "fourier.png",
-            nrow=28,
-        )
-
         y = self.train_input.float() @ self.fourier_basis.t()
         self.fourier_range = 4
         self.fourier_mu = y.mean(dim=0, keepdim=True)
@@ -718,24 +711,6 @@ class MNIST(Task):
             self.train_input = self.fourier_encode(self.train_input)
             self.test_input = self.fourier_encode(self.test_input)
 
-            torchvision.utils.save_image(
-                1
-                - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
-                / 256,
-                "check-train.png",
-                nrow=16,
-            )
-            torchvision.utils.save_image(
-                1
-                - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
-                / 256,
-                "check-test.png",
-                nrow=16,
-            )
-            # exit(0)
-
-            print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}")
-
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -754,6 +729,26 @@ class MNIST(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
+        if n_epoch == 0:
+            image_name = os.path.join(result_dir, "fourier.png")
+            torchvision.utils.save_image(
+                0.5
+                - 0.5
+                * self.fourier_basis.reshape(-1, 1, 28, 28)
+                / self.fourier_basis.std(),
+                image_name,
+                nrow=28,
+            )
+
+            image_name = os.path.join(result_dir, "check-train.png")
+            torchvision.utils.save_image(
+                1
+                - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                image_name,
+                nrow=16,
+            )
+
         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
         ar_mask = torch.full_like(results, 1)
         masked_inplace_autoregression(