Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 21:06:25 +0000 (23:06 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 29 Sep 2024 21:06:25 +0000 (23:06 +0200)
tasks.py

index 4ec410f..450a495 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -633,23 +633,80 @@ class PicoCLVR(Task):
 ######################################################################
 
 
+def generate_2d_fourier_basis(T):
+    # Create 1D vectors for time/space in both dimensions
+    t = torch.linspace(0, T - 1, T)
+
+    # Initialize an empty list to hold the basis vectors
+    basis = [torch.ones(T, T)]  # The constant (DC) component
+
+    # Generate cosine and sine terms for both dimensions
+    for nx in range(1, T // 2 + 1):
+        for ny in range(1, T // 2 + 1):
+            # Cosine and sine components in x- and y-directions
+            cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1)
+            sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1)
+            cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0)
+            sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0)
+
+            # Basis functions in 2D as outer products
+            basis.append(torch.mm(cos_x, cos_y))  # cos(nx)cos(ny)
+            basis.append(torch.mm(sin_x, sin_y))  # sin(nx)sin(ny)
+            basis.append(torch.mm(cos_x, sin_y))  # cos(nx)sin(ny)
+            basis.append(torch.mm(sin_x, cos_y))  # sin(nx)cos(ny)
+
+    # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T)
+    basis_matrix = torch.stack(basis[: T * T], dim=0)
+
+    return basis_matrix
+
+
 class MNIST(Task):
-    def fourier_encode(self, x):
-        y = torch.linalg.lstsq(
-            self.fourier_basis.to(x.device).t(), x.float().t()
-        ).solution.t()
-        y = ((y / 255).clamp(min=0, max=1) * 255).long()
+    def create_global_basis(self):
+        import numpy as np
+
+        # self.global_basis = torch.randn(784, 784)
+
+        self.global_basis = generate_2d_fourier_basis(T=28).flatten(1)
+        self.global_basis_inverse = self.global_basis.inverse()
+
+        torchvision.utils.save_image(
+            1 - self.global_basis / self.global_basis.std(),
+            "fourier.png",
+            nrow=28,
+        )
+
+        y = self.train_input.float() @ self.global_basis.t()
+        self.range = 4
+        self.global_mu = y.mean(dim=0, keepdim=True)
+        self.global_std = y.std(dim=0, keepdim=True)
+
+        # for k in range(25):
+        # X = self.global_encode(self.train_input).float()
+        # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item())
+        # sys.stdout.flush()
+        # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution
+
+    def global_encode(self, x):
+        y = x.float() @ self.global_basis.t()
+        y = ((y - self.global_mu) / self.global_std).clamp(
+            min=-self.range, max=self.range
+        )
+        y = (((y + self.range) / (2 * self.range)) * 255).long()
         return y
 
-    def fourier_decode(self, y):
-        return y.float() @ self.fourier_basis.to(y.device)
+    def global_decode(self, y):
+        y = (
+            (y / 255.0) * (2 * self.range) - self.range
+        ) * self.global_std + self.global_mu
+        return y.float() @ self.global_basis_inverse.to(y.device).t()
 
     def __init__(
         self,
         nb_train_samples,
         nb_test_samples,
         batch_size,
-        fourier=True,
+        global_representation=True,
         device=torch.device("cpu"),
     ):
         super().__init__()
@@ -663,24 +720,26 @@ class MNIST(Task):
         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
 
-        self.fourier = fourier
-
-        if fourier:
-            self.create_fourier_basis()
+        self.global_representation = global_representation
 
-            # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}")
+        if global_representation:
+            self.create_global_basis()
 
-            self.train_input = self.fourier_encode(self.train_input[:256])
-            self.test_input = self.fourier_encode(self.test_input[:256])
+            self.train_input = self.global_encode(self.train_input)
+            self.test_input = self.global_encode(self.test_input)
 
             torchvision.utils.save_image(
-                1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256,
-                "train.png",
+                1
+                - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                "check-train.png",
                 nrow=16,
             )
             torchvision.utils.save_image(
-                1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256,
-                "test.png",
+                1
+                - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                "check-test.png",
                 nrow=16,
             )
             # exit(0)
@@ -716,7 +775,8 @@ class MNIST(Task):
             device=self.device,
         )
 
-        results = self.fourier_decode(results)
+        if self.global_representation:
+            results = self.global_decode(results)
 
         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(