From 22415499c0a91922e51f9e2cade009fd404351dc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 14 Jun 2024 11:34:41 +0200 Subject: [PATCH 1/8] Update. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 37515b5..3ff64b7 100755 --- a/main.py +++ b/main.py @@ -844,7 +844,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): input = input.to(device) bs = model(mygpt.BracketedSequence(input)) - output_ar = bs.x + output = bs.x loss = F.cross_entropy(output.transpose(1, 2), input) -- 2.39.5 From b6228999b93968b7362b70b1b570e622a954b805 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jun 2024 15:41:51 +0200 Subject: [PATCH 2/8] Update. --- turing.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100755 turing.py diff --git a/turing.py b/turing.py new file mode 100755 index 0000000..66c7f03 --- /dev/null +++ b/turing.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python + +import torch + + +def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size=5): + next_state = torch.randint(nb_states, (N, nb_states, nb_symbols)) + next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols)) + next_move = torch.randint(3, (N, nb_states, nb_symbols)) + + all_n = torch.arange(N) + + tape = torch.randint(nb_symbols, (N, tape_size)) + position = torch.randint(tape_size, (N,)) + state = torch.randint(nb_states, (N,)) + + result = [] + + for _ in range(nb_iter): + result.append(tape) + current_symbol = tape[all_n, position] + tape[all_n, position] = next_symbol[all_n, state, current_symbol] + position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size + state = next_state[all_n, state, current_symbol] + + result = torch.cat([x[:, None, :] for x in result], dim=1) + + return result + + +###################################################################### + +if __name__ == "__main__": + print("Basic check.") + + tapes = generate_turing_sequences(5) + + for i in range(tapes.size(1)): + print(f"- {i:03d} ------------------------") + # for s, h, r in zip(state, position, tape): + # print("".join([f"{x}" for x in r])) + # print(" " * h + f"^[{s}]") + for r in tapes: + print("".join([f"{x}" for x in r[i]])) -- 2.39.5 From cf94b49d085ec05e1053b49b7e796afa3f593a28 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jun 2024 15:53:17 +0200 Subject: [PATCH 3/8] Update. --- turing.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/turing.py b/turing.py index 66c7f03..2bcdeeb 100755 --- a/turing.py +++ b/turing.py @@ -3,7 +3,7 @@ import torch -def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size=5): +def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5): next_state = torch.randint(nb_states, (N, nb_states, nb_symbols)) next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols)) next_move = torch.randint(3, (N, nb_states, nb_symbols)) @@ -11,13 +11,15 @@ def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size all_n = torch.arange(N) tape = torch.randint(nb_symbols, (N, tape_size)) - position = torch.randint(tape_size, (N,)) - state = torch.randint(nb_states, (N,)) + # position = torch.randint(tape_size, (N,)) + # state = torch.randint(nb_states, (N,)) + position = torch.zeros(N, dtype=torch.int64) + state = torch.zeros(N, dtype=torch.int64) result = [] for _ in range(nb_iter): - result.append(tape) + result.append(tape.clone()) current_symbol = tape[all_n, position] tape[all_n, position] = next_symbol[all_n, state, current_symbol] position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size @@ -33,10 +35,10 @@ def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size if __name__ == "__main__": print("Basic check.") - tapes = generate_turing_sequences(5) + tapes = generate_turing_sequences(1, nb_iter=10) for i in range(tapes.size(1)): - print(f"- {i:03d} ------------------------") + # print(f"- {i:03d} ------------------------") # for s, h, r in zip(state, position, tape): # print("".join([f"{x}" for x in r])) # print(" " * h + f"^[{s}]") -- 2.39.5 From 5f1df36dae9888175a44f2c0d9651f67a457904c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 10:59:36 +0200 Subject: [PATCH 4/8] Update. --- tasks.py | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 443419e..4ec410f 100755 --- a/tasks.py +++ b/tasks.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -634,8 +634,23 @@ class PicoCLVR(Task): class MNIST(Task): + def fourier_encode(self, x): + y = torch.linalg.lstsq( + self.fourier_basis.to(x.device).t(), x.float().t() + ).solution.t() + y = ((y / 255).clamp(min=0, max=1) * 255).long() + return y + + def fourier_decode(self, y): + return y.float() @ self.fourier_basis.to(y.device) + def __init__( - self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + self, + nb_train_samples, + nb_test_samples, + batch_size, + fourier=True, + device=torch.device("cpu"), ): super().__init__() @@ -648,6 +663,30 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + self.fourier = fourier + + if fourier: + self.create_fourier_basis() + + # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}") + + self.train_input = self.fourier_encode(self.train_input[:256]) + self.test_input = self.fourier_encode(self.test_input[:256]) + + torchvision.utils.save_image( + 1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256, + "train.png", + nrow=16, + ) + torchvision.utils.save_image( + 1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256, + "test.png", + nrow=16, + ) + # exit(0) + + print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}") + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -676,6 +715,9 @@ class MNIST(Task): deterministic_synthesis, device=self.device, ) + + results = self.fourier_decode(results) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, -- 2.39.5 From 80c23832f3c8a29ef684ec68e580c8af68c43414 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 23:06:25 +0200 Subject: [PATCH 5/8] Update. --- tasks.py | 100 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/tasks.py b/tasks.py index 4ec410f..450a495 100755 --- a/tasks.py +++ b/tasks.py @@ -633,23 +633,80 @@ class PicoCLVR(Task): ###################################################################### +def generate_2d_fourier_basis(T): + # Create 1D vectors for time/space in both dimensions + t = torch.linspace(0, T - 1, T) + + # Initialize an empty list to hold the basis vectors + basis = [torch.ones(T, T)] # The constant (DC) component + + # Generate cosine and sine terms for both dimensions + for nx in range(1, T // 2 + 1): + for ny in range(1, T // 2 + 1): + # Cosine and sine components in x- and y-directions + cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1) + sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1) + cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0) + sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0) + + # Basis functions in 2D as outer products + basis.append(torch.mm(cos_x, cos_y)) # cos(nx)cos(ny) + basis.append(torch.mm(sin_x, sin_y)) # sin(nx)sin(ny) + basis.append(torch.mm(cos_x, sin_y)) # cos(nx)sin(ny) + basis.append(torch.mm(sin_x, cos_y)) # sin(nx)cos(ny) + + # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T) + basis_matrix = torch.stack(basis[: T * T], dim=0) + + return basis_matrix + + class MNIST(Task): - def fourier_encode(self, x): - y = torch.linalg.lstsq( - self.fourier_basis.to(x.device).t(), x.float().t() - ).solution.t() - y = ((y / 255).clamp(min=0, max=1) * 255).long() + def create_global_basis(self): + import numpy as np + + # self.global_basis = torch.randn(784, 784) + + self.global_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.global_basis_inverse = self.global_basis.inverse() + + torchvision.utils.save_image( + 1 - self.global_basis / self.global_basis.std(), + "fourier.png", + nrow=28, + ) + + y = self.train_input.float() @ self.global_basis.t() + self.range = 4 + self.global_mu = y.mean(dim=0, keepdim=True) + self.global_std = y.std(dim=0, keepdim=True) + + # for k in range(25): + # X = self.global_encode(self.train_input).float() + # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item()) + # sys.stdout.flush() + # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution + + def global_encode(self, x): + y = x.float() @ self.global_basis.t() + y = ((y - self.global_mu) / self.global_std).clamp( + min=-self.range, max=self.range + ) + y = (((y + self.range) / (2 * self.range)) * 255).long() return y - def fourier_decode(self, y): - return y.float() @ self.fourier_basis.to(y.device) + def global_decode(self, y): + y = ( + (y / 255.0) * (2 * self.range) - self.range + ) * self.global_std + self.global_mu + return y.float() @ self.global_basis_inverse.to(y.device).t() def __init__( self, nb_train_samples, nb_test_samples, batch_size, - fourier=True, + global_representation=True, device=torch.device("cpu"), ): super().__init__() @@ -663,24 +720,26 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() - self.fourier = fourier - - if fourier: - self.create_fourier_basis() + self.global_representation = global_representation - # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}") + if global_representation: + self.create_global_basis() - self.train_input = self.fourier_encode(self.train_input[:256]) - self.test_input = self.fourier_encode(self.test_input[:256]) + self.train_input = self.global_encode(self.train_input) + self.test_input = self.global_encode(self.test_input) torchvision.utils.save_image( - 1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256, - "train.png", + 1 + - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + "check-train.png", nrow=16, ) torchvision.utils.save_image( - 1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256, - "test.png", + 1 + - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) + / 256, + "check-test.png", nrow=16, ) # exit(0) @@ -716,7 +775,8 @@ class MNIST(Task): device=self.device, ) - results = self.fourier_decode(results) + if self.global_representation: + results = self.global_decode(results) image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( -- 2.39.5 From a5a653bbebcd89b616918f44543c85140ecdaa33 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 29 Sep 2024 23:55:01 +0200 Subject: [PATCH 6/8] Update. --- tasks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tasks.py b/tasks.py index 450a495..9100632 100755 --- a/tasks.py +++ b/tasks.py @@ -671,7 +671,7 @@ class MNIST(Task): self.global_basis_inverse = self.global_basis.inverse() torchvision.utils.save_image( - 1 - self.global_basis / self.global_basis.std(), + 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(), "fourier.png", nrow=28, ) @@ -696,9 +696,9 @@ class MNIST(Task): return y def global_decode(self, y): - y = ( - (y / 255.0) * (2 * self.range) - self.range - ) * self.global_std + self.global_mu + y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to( + y.device + ) + self.global_mu.to(y.device) return y.float() @ self.global_basis_inverse.to(y.device).t() def __init__( -- 2.39.5 From 0057bf10615d045a043014cfe1bcb01ba9c3871d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 30 Sep 2024 08:16:20 +0200 Subject: [PATCH 7/8] Update. --- main.py | 6 +++++ tasks.py | 68 ++++++++++++++++++++++++-------------------------------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/main.py b/main.py index 3ff64b7..fe7b49e 100755 --- a/main.py +++ b/main.py @@ -88,6 +88,11 @@ parser.add_argument("--resume", action="store_true", default=False) parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") +############################## +# MNIST + +parser.add_argument("--mnist_fourier", action="store_true", default=False) + ############################## # filetask @@ -546,6 +551,7 @@ elif args.task == "mnist": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, + fourier_representation=args.mnist_fourier, device=device, ) diff --git a/tasks.py b/tasks.py index 9100632..5f3258b 100755 --- a/tasks.py +++ b/tasks.py @@ -662,51 +662,41 @@ def generate_2d_fourier_basis(T): class MNIST(Task): - def create_global_basis(self): - import numpy as np - - # self.global_basis = torch.randn(784, 784) - - self.global_basis = generate_2d_fourier_basis(T=28).flatten(1) - self.global_basis_inverse = self.global_basis.inverse() + def create_fourier_basis(self): + self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.fourier_basis_inverse = self.fourier_basis.inverse() torchvision.utils.save_image( - 1 - self.global_basis.reshape(-1, 1, 28, 28) / self.global_basis.std(), + 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(), "fourier.png", nrow=28, ) - y = self.train_input.float() @ self.global_basis.t() - self.range = 4 - self.global_mu = y.mean(dim=0, keepdim=True) - self.global_std = y.std(dim=0, keepdim=True) - - # for k in range(25): - # X = self.global_encode(self.train_input).float() - # print(k, (self.global_decode(X) - self.train_input).pow(2).mean().item()) - # sys.stdout.flush() - # self.global_basis = torch.linalg.lstsq(X, self.train_input.float()).solution - - def global_encode(self, x): - y = x.float() @ self.global_basis.t() - y = ((y - self.global_mu) / self.global_std).clamp( - min=-self.range, max=self.range + y = self.train_input.float() @ self.fourier_basis.t() + self.fourier_range = 4 + self.fourier_mu = y.mean(dim=0, keepdim=True) + self.fourier_std = y.std(dim=0, keepdim=True) + + def fourier_encode(self, x): + y = x.float() @ self.fourier_basis.t() + y = ((y - self.fourier_mu) / self.fourier_std).clamp( + min=-self.fourier_range, max=self.fourier_range ) - y = (((y + self.range) / (2 * self.range)) * 255).long() + y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long() return y - def global_decode(self, y): - y = ((y / 255.0) * (2 * self.range) - self.range) * self.global_std.to( - y.device - ) + self.global_mu.to(y.device) - return y.float() @ self.global_basis_inverse.to(y.device).t() + def fourier_decode(self, y): + y = ( + (y / 255.0) * (2 * self.fourier_range) - self.fourier_range + ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device) + return y.float() @ self.fourier_basis_inverse.to(y.device).t() def __init__( self, nb_train_samples, nb_test_samples, batch_size, - global_representation=True, + fourier_representation=True, device=torch.device("cpu"), ): super().__init__() @@ -720,24 +710,24 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() - self.global_representation = global_representation + self.fourier_representation = fourier_representation - if global_representation: - self.create_global_basis() + if fourier_representation: + self.create_fourier_basis() - self.train_input = self.global_encode(self.train_input) - self.test_input = self.global_encode(self.test_input) + self.train_input = self.fourier_encode(self.train_input) + self.test_input = self.fourier_encode(self.test_input) torchvision.utils.save_image( 1 - - self.global_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) / 256, "check-train.png", nrow=16, ) torchvision.utils.save_image( 1 - - self.global_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) + - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) / 256, "check-test.png", nrow=16, @@ -775,8 +765,8 @@ class MNIST(Task): device=self.device, ) - if self.global_representation: - results = self.global_decode(results) + if self.fourier_representation: + results = self.fourier_decode(results) image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( -- 2.39.5 From 3882ebe7fdc72ac5cd3b39bc6d5d4c4924a20a9e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 30 Sep 2024 08:28:37 +0200 Subject: [PATCH 8/8] Update. --- tasks.py | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tasks.py b/tasks.py index 5f3258b..9901715 100755 --- a/tasks.py +++ b/tasks.py @@ -665,13 +665,6 @@ class MNIST(Task): def create_fourier_basis(self): self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) self.fourier_basis_inverse = self.fourier_basis.inverse() - - torchvision.utils.save_image( - 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(), - "fourier.png", - nrow=28, - ) - y = self.train_input.float() @ self.fourier_basis.t() self.fourier_range = 4 self.fourier_mu = y.mean(dim=0, keepdim=True) @@ -718,24 +711,6 @@ class MNIST(Task): self.train_input = self.fourier_encode(self.train_input) self.test_input = self.fourier_encode(self.test_input) - torchvision.utils.save_image( - 1 - - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) - / 256, - "check-train.png", - nrow=16, - ) - torchvision.utils.save_image( - 1 - - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28) - / 256, - "check-test.png", - nrow=16, - ) - # exit(0) - - print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}") - def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -754,6 +729,26 @@ class MNIST(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): + if n_epoch == 0: + image_name = os.path.join(result_dir, "fourier.png") + torchvision.utils.save_image( + 0.5 + - 0.5 + * self.fourier_basis.reshape(-1, 1, 28, 28) + / self.fourier_basis.std(), + image_name, + nrow=28, + ) + + image_name = os.path.join(result_dir, "check-train.png") + torchvision.utils.save_image( + 1 + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + image_name, + nrow=16, + ) + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) ar_mask = torch.full_like(results, 1) masked_inplace_autoregression( -- 2.39.5