Update. master
authorFrançois Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000 (08:28 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000 (08:28 +0200)
main.py
tasks.py
turing.py [new file with mode: 0755]

diff --git a/main.py b/main.py
index 37515b5..fe7b49e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -88,6 +88,11 @@ parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
+##############################
+# MNIST
+
+parser.add_argument("--mnist_fourier", action="store_true", default=False)
+
 ##############################
 # filetask
 
@@ -546,6 +551,7 @@ elif args.task == "mnist":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.physical_batch_size,
+        fourier_representation=args.mnist_fourier,
         device=device,
     )
 
@@ -844,7 +850,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
             input = input.to(device)
 
             bs = model(mygpt.BracketedSequence(input))
-            output_ar = bs.x
+            output = bs.x
 
             loss = F.cross_entropy(output.transpose(1, 2), input)
 
index 443419e..9901715 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
 
 import torch, torchvision
 
@@ -633,9 +633,64 @@ class PicoCLVR(Task):
 ######################################################################
 
 
+def generate_2d_fourier_basis(T):
+    # Create 1D vectors for time/space in both dimensions
+    t = torch.linspace(0, T - 1, T)
+
+    # Initialize an empty list to hold the basis vectors
+    basis = [torch.ones(T, T)]  # The constant (DC) component
+
+    # Generate cosine and sine terms for both dimensions
+    for nx in range(1, T // 2 + 1):
+        for ny in range(1, T // 2 + 1):
+            # Cosine and sine components in x- and y-directions
+            cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1)
+            sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1)
+            cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0)
+            sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0)
+
+            # Basis functions in 2D as outer products
+            basis.append(torch.mm(cos_x, cos_y))  # cos(nx)cos(ny)
+            basis.append(torch.mm(sin_x, sin_y))  # sin(nx)sin(ny)
+            basis.append(torch.mm(cos_x, sin_y))  # cos(nx)sin(ny)
+            basis.append(torch.mm(sin_x, cos_y))  # sin(nx)cos(ny)
+
+    # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T)
+    basis_matrix = torch.stack(basis[: T * T], dim=0)
+
+    return basis_matrix
+
+
 class MNIST(Task):
+    def create_fourier_basis(self):
+        self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
+        self.fourier_basis_inverse = self.fourier_basis.inverse()
+        y = self.train_input.float() @ self.fourier_basis.t()
+        self.fourier_range = 4
+        self.fourier_mu = y.mean(dim=0, keepdim=True)
+        self.fourier_std = y.std(dim=0, keepdim=True)
+
+    def fourier_encode(self, x):
+        y = x.float() @ self.fourier_basis.t()
+        y = ((y - self.fourier_mu) / self.fourier_std).clamp(
+            min=-self.fourier_range, max=self.fourier_range
+        )
+        y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long()
+        return y
+
+    def fourier_decode(self, y):
+        y = (
+            (y / 255.0) * (2 * self.fourier_range) - self.fourier_range
+        ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device)
+        return y.float() @ self.fourier_basis_inverse.to(y.device).t()
+
     def __init__(
-        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        fourier_representation=True,
+        device=torch.device("cpu"),
     ):
         super().__init__()
 
@@ -648,6 +703,14 @@ class MNIST(Task):
         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
 
+        self.fourier_representation = fourier_representation
+
+        if fourier_representation:
+            self.create_fourier_basis()
+
+            self.train_input = self.fourier_encode(self.train_input)
+            self.test_input = self.fourier_encode(self.test_input)
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -666,6 +729,26 @@ class MNIST(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
+        if n_epoch == 0:
+            image_name = os.path.join(result_dir, "fourier.png")
+            torchvision.utils.save_image(
+                0.5
+                - 0.5
+                * self.fourier_basis.reshape(-1, 1, 28, 28)
+                / self.fourier_basis.std(),
+                image_name,
+                nrow=28,
+            )
+
+            image_name = os.path.join(result_dir, "check-train.png")
+            torchvision.utils.save_image(
+                1
+                - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                image_name,
+                nrow=16,
+            )
+
         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
         ar_mask = torch.full_like(results, 1)
         masked_inplace_autoregression(
@@ -676,6 +759,10 @@ class MNIST(Task):
             deterministic_synthesis,
             device=self.device,
         )
+
+        if self.fourier_representation:
+            results = self.fourier_decode(results)
+
         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
diff --git a/turing.py b/turing.py
new file mode 100755 (executable)
index 0000000..2bcdeeb
--- /dev/null
+++ b/turing.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+
+import torch
+
+
+def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5):
+    next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
+    next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
+    next_move = torch.randint(3, (N, nb_states, nb_symbols))
+
+    all_n = torch.arange(N)
+
+    tape = torch.randint(nb_symbols, (N, tape_size))
+    # position = torch.randint(tape_size, (N,))
+    # state = torch.randint(nb_states, (N,))
+    position = torch.zeros(N, dtype=torch.int64)
+    state = torch.zeros(N, dtype=torch.int64)
+
+    result = []
+
+    for _ in range(nb_iter):
+        result.append(tape.clone())
+        current_symbol = tape[all_n, position]
+        tape[all_n, position] = next_symbol[all_n, state, current_symbol]
+        position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size
+        state = next_state[all_n, state, current_symbol]
+
+    result = torch.cat([x[:, None, :] for x in result], dim=1)
+
+    return result
+
+
+######################################################################
+
+if __name__ == "__main__":
+    print("Basic check.")
+
+    tapes = generate_turing_sequences(1, nb_iter=10)
+
+    for i in range(tapes.size(1)):
+        # print(f"- {i:03d} ------------------------")
+        # for s, h, r in zip(state, position, tape):
+        # print("".join([f"{x}" for x in r]))
+        # print(" " * h + f"^[{s}]")
+        for r in tapes:
+            print("".join([f"{x}" for x in r[i]]))