# Written by Francois Fleuret <francois@fleuret.org>
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
import torch, torchvision
######################################################################
+def generate_2d_fourier_basis(T):
+ # Create 1D vectors for time/space in both dimensions
+ t = torch.linspace(0, T - 1, T)
+
+ # Initialize an empty list to hold the basis vectors
+ basis = [torch.ones(T, T)] # The constant (DC) component
+
+ # Generate cosine and sine terms for both dimensions
+ for nx in range(1, T // 2 + 1):
+ for ny in range(1, T // 2 + 1):
+ # Cosine and sine components in x- and y-directions
+ cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1)
+ sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1)
+ cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0)
+ sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0)
+
+ # Basis functions in 2D as outer products
+ basis.append(torch.mm(cos_x, cos_y)) # cos(nx)cos(ny)
+ basis.append(torch.mm(sin_x, sin_y)) # sin(nx)sin(ny)
+ basis.append(torch.mm(cos_x, sin_y)) # cos(nx)sin(ny)
+ basis.append(torch.mm(sin_x, cos_y)) # sin(nx)cos(ny)
+
+ # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T)
+ basis_matrix = torch.stack(basis[: T * T], dim=0)
+
+ return basis_matrix
+
+
class MNIST(Task):
+ def create_fourier_basis(self):
+ self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
+ self.fourier_basis_inverse = self.fourier_basis.inverse()
+ y = self.train_input.float() @ self.fourier_basis.t()
+ self.fourier_range = 4
+ self.fourier_mu = y.mean(dim=0, keepdim=True)
+ self.fourier_std = y.std(dim=0, keepdim=True)
+
+ def fourier_encode(self, x):
+ y = x.float() @ self.fourier_basis.t()
+ y = ((y - self.fourier_mu) / self.fourier_std).clamp(
+ min=-self.fourier_range, max=self.fourier_range
+ )
+ y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long()
+ return y
+
+ def fourier_decode(self, y):
+ y = (
+ (y / 255.0) * (2 * self.fourier_range) - self.fourier_range
+ ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device)
+ return y.float() @ self.fourier_basis_inverse.to(y.device).t()
+
def __init__(
- self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ fourier_representation=True,
+ device=torch.device("cpu"),
):
super().__init__()
data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
+ self.fourier_representation = fourier_representation
+
+ if fourier_representation:
+ self.create_fourier_basis()
+
+ self.train_input = self.fourier_encode(self.train_input)
+ self.test_input = self.fourier_encode(self.test_input)
+
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
+ if n_epoch == 0:
+ image_name = os.path.join(result_dir, "fourier.png")
+ torchvision.utils.save_image(
+ 0.5
+ - 0.5
+ * self.fourier_basis.reshape(-1, 1, 28, 28)
+ / self.fourier_basis.std(),
+ image_name,
+ nrow=28,
+ )
+
+ image_name = os.path.join(result_dir, "check-train.png")
+ torchvision.utils.save_image(
+ 1
+ - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+ / 256,
+ image_name,
+ nrow=16,
+ )
+
results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
ar_mask = torch.full_like(results, 1)
masked_inplace_autoregression(
deterministic_synthesis,
device=self.device,
)
+
+ if self.fourier_representation:
+ results = self.fourier_decode(results)
+
image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
torchvision.utils.save_image(
1 - results.reshape(-1, 1, 28, 28) / 255.0,
--- /dev/null
+#!/usr/bin/env python
+
+import torch
+
+
+def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5):
+ next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
+ next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
+ next_move = torch.randint(3, (N, nb_states, nb_symbols))
+
+ all_n = torch.arange(N)
+
+ tape = torch.randint(nb_symbols, (N, tape_size))
+ # position = torch.randint(tape_size, (N,))
+ # state = torch.randint(nb_states, (N,))
+ position = torch.zeros(N, dtype=torch.int64)
+ state = torch.zeros(N, dtype=torch.int64)
+
+ result = []
+
+ for _ in range(nb_iter):
+ result.append(tape.clone())
+ current_symbol = tape[all_n, position]
+ tape[all_n, position] = next_symbol[all_n, state, current_symbol]
+ position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size
+ state = next_state[all_n, state, current_symbol]
+
+ result = torch.cat([x[:, None, :] for x in result], dim=1)
+
+ return result
+
+
+######################################################################
+
+if __name__ == "__main__":
+ print("Basic check.")
+
+ tapes = generate_turing_sequences(1, nb_iter=10)
+
+ for i in range(tapes.size(1)):
+ # print(f"- {i:03d} ------------------------")
+ # for s, h, r in zip(state, position, tape):
+ # print("".join([f"{x}" for x in r]))
+ # print(" " * h + f"^[{s}]")
+ for r in tapes:
+ print("".join([f"{x}" for x in r[i]]))