Update.
author François Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:16:20 +0000 (08:16 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:16:20 +0000 (08:16 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 3ff64b7..fe7b49e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -88,6 +88,11 @@ parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
+##############################
+# MNIST
+
+parser.add_argument("--mnist_fourier", action="store_true", default=False)
+
 ##############################
 # filetask
 
@@ -546,6 +551,7 @@ elif args.task == "mnist":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.physical_batch_size,
+        fourier_representation=args.mnist_fourier,
         device=device,
     )
 
diff --git a/tasks.py b/tasks.py
index 9100632..5f3258b 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -662,51 +662,41 @@ def generate_2d_fourier_basis(T):
 
 
 class MNIST(Task):
-    def create_global_basis(self):
-        import numpy as np
-
-        # self.global_basis = torch.randn(784, 784)
-
-        self.global_basis = generate_2d_fourier_basis(T=28).flatten(1)
-        self.global_basis_inverse = self.global_basis.inverse()
+    def create_fourier_basis(self):
+        self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
+        self.fourier_basis_inverse = self.fourier_basis.inverse()
 
         torchvision.utils.save_image(
-            1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(),
+            1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(),
             "fourier.png",
             nrow=28,
         )
 
-        y = self.train_input.float() @ self.global_basis.t()
-        self.range = 4
-        self.global_mu = y.mean(dim=0, keepdim=True)
-        self.global_std = y.std(dim=0, keepdim=True)
-
-        # for k in range(25):
-        # X = self.global_encode(self.train_input).float()
-        # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item())
-        # sys.stdout.flush()
-        # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution
-
-    def global_encode(self, x):
-        y = x.float() @ self.global_basis.t()
-        y = ((y - self.global_mu) / self.global_std).clamp(
-            min=-self.range, max=self.range
+        y = self.train_input.float() @ self.fourier_basis.t()
+        self.fourier_range = 4
+        self.fourier_mu = y.mean(dim=0, keepdim=True)
+        self.fourier_std = y.std(dim=0, keepdim=True)
+
+    def fourier_encode(self, x):
+        y = x.float() @ self.fourier_basis.t()
+        y = ((y - self.fourier_mu) / self.fourier_std).clamp(
+            min=-self.fourier_range, max=self.fourier_range
         )
-        y = (((y + self.range) / (2 * self.range)) * 255).long()
+        y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long()
         return y
 
-    def global_decode(self, y):
-        y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to(
-            y.device
-        ) + self.global_mu.to(y.device)
-        return y.float() @ self.global_basis_inverse.to(y.device).t()
+    def fourier_decode(self, y):
+        y = (
+            (y / 255.0) * (2 * self.fourier_range) - self.fourier_range
+        ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device)
+        return y.float() @ self.fourier_basis_inverse.to(y.device).t()
 
     def __init__(
         self,
         nb_train_samples,
         nb_test_samples,
         batch_size,
-        global_representation=True,
+        fourier_representation=True,
         device=torch.device("cpu"),
     ):
         super().__init__()
@@ -720,24 +710,24 @@ class MNIST(Task):
         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
 
-        self.global_representation = global_representation
+        self.fourier_representation = fourier_representation
 
-        if global_representation:
-            self.create_global_basis()
+        if fourier_representation:
+            self.create_fourier_basis()
 
-            self.train_input = self.global_encode(self.train_input)
-            self.test_input = self.global_encode(self.test_input)
+            self.train_input = self.fourier_encode(self.train_input)
+            self.test_input = self.fourier_encode(self.test_input)
 
             torchvision.utils.save_image(
                 1
-                - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
                 / 256,
                 "check-train.png",
                 nrow=16,
             )
             torchvision.utils.save_image(
                 1
-                - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
+                - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
                 / 256,
                 "check-test.png",
                 nrow=16,
@@ -775,8 +765,8 @@ class MNIST(Task):
             device=self.device,
         )
 
-        if self.global_representation:
-            results = self.global_decode(results)
+        if self.fourier_representation:
+            results = self.fourier_decode(results)
 
         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(