From: François Fleuret Date: Mon, 30 Sep 2024 06:28:37 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fmaster;hp=d95b9b72b0f098b5c955395905a0aff710f553a7;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 37515b5..fe7b49e 100755 --- a/main.py +++ b/main.py @@ -88,6 +88,11 @@ parser.add_argument("--resume", action="store_true", default=False) parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") +############################## +# MNIST + +parser.add_argument("--mnist_fourier", action="store_true", default=False) + ############################## # filetask @@ -546,6 +551,7 @@ elif args.task == "mnist": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, + fourier_representation=args.mnist_fourier, device=device, ) @@ -844,7 +850,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): input = input.to(device) bs = model(mygpt.BracketedSequence(input)) - output_ar = bs.x + output = bs.x loss = F.cross_entropy(output.transpose(1, 2), input) diff --git a/tasks.py b/tasks.py index 443419e..9901715 100755 --- a/tasks.py +++ b/tasks.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -633,9 +633,64 @@ class PicoCLVR(Task): ###################################################################### +def generate_2d_fourier_basis(T): + # Create 1D vectors for time/space in both dimensions + t = torch.linspace(0, T - 1, T) + + # Initialize an empty list to hold the basis vectors + basis = [torch.ones(T, T)] # The constant (DC) component + + # Generate cosine and sine terms for both dimensions + for nx in range(1, T // 2 + 1): + for ny in range(1, T // 2 + 1): + # Cosine and sine components in x- and y-directions + cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1) + sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1) + cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0) + sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0) + + # Basis functions in 2D as outer products + basis.append(torch.mm(cos_x, cos_y)) # cos(nx)cos(ny) + basis.append(torch.mm(sin_x, sin_y)) # sin(nx)sin(ny) + basis.append(torch.mm(cos_x, sin_y)) # cos(nx)sin(ny) + basis.append(torch.mm(sin_x, cos_y)) # sin(nx)cos(ny) + + # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T) + basis_matrix = torch.stack(basis[: T * T], dim=0) + + return basis_matrix + + class MNIST(Task): + def create_fourier_basis(self): + self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.fourier_basis_inverse = self.fourier_basis.inverse() + y = self.train_input.float() @ self.fourier_basis.t() + self.fourier_range = 4 + self.fourier_mu = y.mean(dim=0, keepdim=True) + self.fourier_std = y.std(dim=0, keepdim=True) + + def fourier_encode(self, x): + y = x.float() @ self.fourier_basis.t() + y = ((y - self.fourier_mu) / self.fourier_std).clamp( + min=-self.fourier_range, max=self.fourier_range + ) + y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long() + return y + + def fourier_decode(self, y): + y = ( + (y / 255.0) * (2 * self.fourier_range) - self.fourier_range + ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device) + return y.float() @ self.fourier_basis_inverse.to(y.device).t() + def __init__( - self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + self, + nb_train_samples, + nb_test_samples, + batch_size, + fourier_representation=True, + device=torch.device("cpu"), ): super().__init__() @@ -648,6 +703,14 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + self.fourier_representation = fourier_representation + + if fourier_representation: + self.create_fourier_basis() + + self.train_input = self.fourier_encode(self.train_input) + self.test_input = self.fourier_encode(self.test_input) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -666,6 +729,26 @@ class MNIST(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): + if n_epoch == 0: + image_name = os.path.join(result_dir, "fourier.png") + torchvision.utils.save_image( + 0.5 + - 0.5 + * self.fourier_basis.reshape(-1, 1, 28, 28) + / self.fourier_basis.std(), + image_name, + nrow=28, + ) + + image_name = os.path.join(result_dir, "check-train.png") + torchvision.utils.save_image( + 1 + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + image_name, + nrow=16, + ) + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) ar_mask = torch.full_like(results, 1) masked_inplace_autoregression( @@ -676,6 +759,10 @@ class MNIST(Task): deterministic_synthesis, device=self.device, ) + + if self.fourier_representation: + results = self.fourier_decode(results) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, diff --git a/turing.py b/turing.py new file mode 100755 index 0000000..2bcdeeb --- /dev/null +++ b/turing.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python + +import torch + + +def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5): + next_state = torch.randint(nb_states, (N, nb_states, nb_symbols)) + next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols)) + next_move = torch.randint(3, (N, nb_states, nb_symbols)) + + all_n = torch.arange(N) + + tape = torch.randint(nb_symbols, (N, tape_size)) + # position = torch.randint(tape_size, (N,)) + # state = torch.randint(nb_states, (N,)) + position = torch.zeros(N, dtype=torch.int64) + state = torch.zeros(N, dtype=torch.int64) + + result = [] + + for _ in range(nb_iter): + result.append(tape.clone()) + current_symbol = tape[all_n, position] + tape[all_n, position] = next_symbol[all_n, state, current_symbol] + position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size + state = next_state[all_n, state, current_symbol] + + result = torch.cat([x[:, None, :] for x in result], dim=1) + + return result + + +###################################################################### + +if __name__ == "__main__": + print("Basic check.") + + tapes = generate_turing_sequences(1, nb_iter=10) + + for i in range(tapes.size(1)): + # print(f"- {i:03d} ------------------------") + # for s, h, r in zip(state, position, tape): + # print("".join([f"{x}" for x in r])) + # print(" " * h + f"^[{s}]") + for r in tapes: + print("".join([f"{x}" for x in r[i]]))