Update.
author François Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 08:59:36 +0000 (10:59 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 08:59:36 +0000 (10:59 +0200)
tasks.py

index 443419e..4ec410f 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
 
 import torch, torchvision
 
@@ -634,8 +634,23 @@ class PicoCLVR(Task):
 
 
 class MNIST(Task):
+    def fourier_encode(self, x):  # rows of x -> integer coefficients w.r.t. self.fourier_basis
+        y = torch.linalg.lstsq(  # least-squares solve of basis^T @ c = x^T for c
+            self.fourier_basis.to(x.device).t(), x.float().t()
+        ).solution.t()  # transpose back, so that y @ basis approximates x
+        y = ((y / 255).clamp(min=0, max=1) * 255).long()  # saturate to [0, 255], truncate to int64
+        return y  # NOTE: coefficients outside [0, 255] were clamped, so decoding is lossy
+
+    def fourier_decode(self, y):  # approximate inverse of fourier_encode: coefficients -> samples
+        return y.float() @ self.fourier_basis.to(y.device)  # basis is moved to y's device first
+
     def __init__(
-        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        fourier=True,
+        device=torch.device("cpu"),
     ):
         super().__init__()
 
@@ -648,6 +663,30 @@ class MNIST(Task):
         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
 
+        self.fourier = fourier
+
+        if fourier:
+            self.create_fourier_basis()
+
+            # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}")
+
+            self.train_input = self.fourier_encode(self.train_input[:256])
+            self.test_input = self.fourier_encode(self.test_input[:256])
+
+            torchvision.utils.save_image(
+                1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256,
+                "train.png",
+                nrow=16,
+            )
+            torchvision.utils.save_image(
+                1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256,
+                "test.png",
+                nrow=16,
+            )
+            # exit(0)
+
+            print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}")
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -676,6 +715,9 @@ class MNIST(Task):
             deterministic_synthesis,
             device=self.device,
         )
+
+        results = self.fourier_decode(results)
+
         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,