From e3a8032a070175ece08fc79c77312d5f2f59150e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 14:25:45 +0200 Subject: [PATCH 01/16] Update. --- world.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/world.py b/world.py index da7de75..64c7434 100755 --- a/world.py +++ b/world.py @@ -169,7 +169,7 @@ def train_encoder( train_loss = F.cross_entropy(output, input) if lambda_entropy > 0: - loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5) acc_train_loss += train_loss.item() * input.size(0) @@ -439,26 +439,21 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - # 10000, 1000, - 100, - 100, - nb_epochs=2, + 25000, 1000, + nb_epochs=10, mode="first_last", nb_steps=20, ) - input = test_input[:64] + input = test_input[:256] seq = frame2seq(input) - - print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}") - output = seq2frame(seq) torchvision.utils.save_image( - input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8 + input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16 ) torchvision.utils.save_image( - output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8 + output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16 ) -- 2.20.1 From 3dea181a5903a0e577e4830c66405b40f2a2df1d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 18:48:43 +0200 Subject: [PATCH 02/16] Update. --- tasks.py | 20 +++++++++++++++++--- world.py | 19 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tasks.py b/tasks.py index 8b57cb2..5583fc8 100755 --- a/tasks.py +++ b/tasks.py @@ -73,8 +73,12 @@ class Problem: class ProblemByheart(Problem): def __init__(self): - pass + nb_seq, len_prompt, len_result = 100, 5, 5 + self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq[:,len_prompt]=-1 + def generate_sequences(self, nb): + return self.seq[torch.randint(self.seq.size(0), (nb,))] class SandBox(Task): def __init__( @@ -89,13 +93,23 @@ class SandBox(Task): self.batch_size = batch_size + problems = [ ProblemByheart() ] + nb_common_codes = 100 + def generate_sequences(nb_samples): problem_indexes = torch.randint(len(problems), (nb_samples,)) nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") + all_seq = [] + for nb, p in zip(nb_samples_per_problem,problems): + all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + return all_seq + + train_seq = generate_sequences(nb_train_samples) + test_seq = generate_sequences(nb_test_samples) - self.train_input = generate_sequences(nb_train_samples) - self.test_input = generate_sequences(nb_test_samples) + for strain, stest in zip(train_seq, test_seq): + s = torch.cat((strain,stest),0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 diff --git a/world.py b/world.py index 64c7434..3d6abbe 100755 --- a/world.py +++ b/world.py @@ -61,6 +61,19 @@ class SignSTE(nn.Module): else: return s +class DiscreteSampler2d(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + s = (x >= x.max(-3,keepdim=True).values).float() + + if self.training: + u = x.softmax(dim=-3) + return s + u - u.detach() + else: + return s + def loss_H(binary_logits, h_threshold=1): p = binary_logits.sigmoid().mean(0) @@ -159,7 +172,7 @@ def train_encoder( for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"): input = input.to(device) z = encoder(input) - zq = z if k < 2 else quantizer(z) + zq = quantizer(z) output = decoder(zq) output = output.reshape( @@ -182,7 +195,7 @@ def train_encoder( for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"): input = input.to(device) z = encoder(input) - zq = z if k < 1 else quantizer(z) + zq = quantizer(z) output = decoder(zq) output = output.reshape( @@ -440,7 +453,7 @@ if __name__ == "__main__": seq2frame, ) = create_data_and_processors( 25000, 1000, - nb_epochs=10, + nb_epochs=5, mode="first_last", nb_steps=20, ) -- 2.20.1 From 0f580d4facb4b4b485d0a38d62d06c0639715b77 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 22:51:28 +0200 Subject: [PATCH 03/16] Update. --- ffutils.py | 108 +++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 1 + tasks.py | 9 +++-- world.py | 8 +++- 4 files changed, 120 insertions(+), 6 deletions(-) create mode 100755 ffutils.py diff --git a/ffutils.py b/ffutils.py new file mode 100755 index 0000000..45f44d8 --- /dev/null +++ b/ffutils.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch +import sys, contextlib + +import torch +from torch import Tensor + +###################################################################### + + +@contextlib.contextmanager +def evaluation(*models): + with torch.inference_mode(): + t = [(m, m.training) for m in models] + for m in models: + m.train(False) + yield + for m, u in t: + m.train(u) + + +###################################################################### + +from torch.utils._python_dispatch import TorchDispatchMode + + +def hasNaN(x): + if torch.is_tensor(x): + return x.isnan().max() + else: + try: + return any([hasNaN(y) for y in x]) + except TypeError: + return False + + +class NaNDetect(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + + if hasNaN(res): + raise RuntimeError( + f"Function {func}(*{args}, **{kwargs}) " "returned a NaN" + ) + return res + + +###################################################################### + + +def exception_hook(exc_type, exc_value, tb): + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. + + """ + + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" + + while tb: + print("--------------------------------------------------\n") + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + print(f' File "{filename}", line {line_no}, in {name}') + print(open(filename, "r").readlines()[line_no - 1]) + + if exc_type in {RuntimeError, ValueError, IndexError, TypeError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") + + print() + tb = tb.tb_next + + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") + + +def activate_tensorstack(): + sys.excepthook = exception_hook + + +###################################################################### + +if __name__ == "__main__": + import torch + + def dummy(a, b): + print(a @ b) + + def blah(a, b): + c = b + b + dummy(a, c) + + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm) diff --git a/main.py b/main.py index 69ee58f..e18887b 100755 --- a/main.py +++ b/main.py @@ -14,6 +14,7 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +import ffutils import mygpt, tasks ###################################################################### diff --git a/tasks.py b/tasks.py index 5583fc8..9cd06ae 100755 --- a/tasks.py +++ b/tasks.py @@ -75,11 +75,12 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:,len_prompt]=-1 + self.seq[:, len_prompt] = -1 def generate_sequences(self, nb): return self.seq[torch.randint(self.seq.size(0), (nb,))] + class SandBox(Task): def __init__( self, @@ -93,7 +94,7 @@ class SandBox(Task): self.batch_size = batch_size - problems = [ ProblemByheart() ] + problems = [ProblemByheart()] nb_common_codes = 100 def generate_sequences(nb_samples): @@ -101,7 +102,7 @@ class SandBox(Task): nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") all_seq = [] - for nb, p in zip(nb_samples_per_problem,problems): + for nb, p in zip(nb_samples_per_problem, problems): all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) return all_seq @@ -109,7 +110,7 @@ class SandBox(Task): test_seq = generate_sequences(nb_test_samples) for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain,stest),0) + s = torch.cat((strain, stest), 0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 diff --git a/world.py b/world.py index 3d6abbe..b35a08e 100755 --- a/world.py +++ b/world.py @@ -85,9 +85,9 @@ def loss_H(binary_logits, h_threshold=1): def train_encoder( train_input, test_input, - depth=2, + depth, + nb_bits_per_token, dim_hidden=48, - nb_bits_per_token=8, lambda_entropy=0.0, lr_start=1e-3, lr_end=1e-4, @@ -366,6 +366,8 @@ def create_data_and_processors( nb_test_samples, mode, nb_steps, + depth=3, + nb_bits_per_token=8, nb_epochs=10, device=torch.device("cpu"), device_storage=torch.device("cpu"), @@ -388,6 +390,8 @@ def create_data_and_processors( encoder, quantizer, decoder = train_encoder( train_input, test_input, + depth=depth, + nb_bits_per_token=nb_bits_per_token, lambda_entropy=1.0, nb_epochs=nb_epochs, logger=logger, -- 2.20.1 From 5366dfd7bd57ec3298d1030f7d5327ff26bc5aad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 08:44:21 +0200 Subject: [PATCH 04/16] Update. --- main.py | 1 + tasks.py | 80 ++++++++++++++++++++++++++++++++++++++------------------ world.py | 11 +++++--- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index e18887b..3be3d55 100755 --- a/main.py +++ b/main.py @@ -266,6 +266,7 @@ picoclvr_pruner_eval = ( if args.task == "sandbox": task = tasks.SandBox( + tasks.ProblemByheart(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/tasks.py b/tasks.py index 9cd06ae..eef84af 100755 --- a/tasks.py +++ b/tasks.py @@ -64,10 +64,10 @@ class Task: class Problem: - def generate(nb): + def generate_sequences(self, nb): pass - def perf(seq, logger): + def log_performance(self, sequences, logger): pass @@ -75,15 +75,33 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:, len_prompt] = -1 + self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): - return self.seq[torch.randint(self.seq.size(0), (nb,))] - + sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] + ar_mask = (sequences==10).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + # problems = [ProblemByheart()] + # nb_common_codes = 100 + + # def generate_sequences(nb_samples): + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq + + # for strain, stest in zip(train_seq, test_seq): + # s = torch.cat((strain, stest), 0) class SandBox(Task): def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -93,24 +111,10 @@ class SandBox(Task): super().__init__() self.batch_size = batch_size + self.device = device - problems = [ProblemByheart()] - nb_common_codes = 100 - - def generate_sequences(nb_samples): - problem_indexes = torch.randint(len(problems), (nb_samples,)) - nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - print(f"{nb_samples_per_problem}") - all_seq = [] - for nb, p in zip(nb_samples_per_problem, problems): - all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - return all_seq - - train_seq = generate_sequences(nb_train_samples) - test_seq = generate_sequences(nb_test_samples) - - for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain, stest), 0) + self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -132,11 +136,35 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - # logger( - # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - # ) - pass + def compute_accuracy(input, ar_mask): + result = input.clone() * (1-ar_mask) + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + nb_total = ar_mask.sum().item() + nb_correct = ((result==input).long() * ar_mask).sum().item() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) ###################################################################### diff --git a/world.py b/world.py index b35a08e..12c6553 100755 --- a/world.py +++ b/world.py @@ -96,8 +96,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - if logger is None: - logger = lambda s: print(s) mu, std = train_input.float().mean(), train_input.float().std() @@ -157,7 +155,7 @@ def train_encoder( nb_parameters = sum(p.numel() for p in model.parameters()) - logger(f"nb_parameters {nb_parameters}") + logger(f"vqae nb_parameters {nb_parameters}") model.to(device) @@ -209,7 +207,7 @@ def train_encoder( train_loss = acc_train_loss / train_input.size(0) test_loss = acc_test_loss / test_input.size(0) - logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") + logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") sys.stdout.flush() return encoder, quantizer, decoder @@ -378,6 +376,9 @@ def create_data_and_processors( if mode == "first_last": steps = [True] + [False] * (nb_steps + 1) + [True] + if logger is None: + logger = lambda s: print(s) + train_input, train_actions = generate_episodes(nb_train_samples, steps) train_input, train_actions = train_input.to(device_storage), train_actions.to( device_storage @@ -405,6 +406,8 @@ def create_data_and_processors( pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :] z_h, z_w = z.size(2), z.size(3) + logger(f"vqae input {train_input[0].size()} output {z[0].size()}") + def frame2seq(input, batch_size=25): seq = [] p = pow2.to(device) -- 2.20.1 From a2ffcd9b27aa0f3cc0b56090a32e88b73dfa0a54 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 08:49:15 +0200 Subject: [PATCH 05/16] Update. --- tasks.py | 47 ++++++++++++++++++++++++++++++++--------------- world.py | 7 ++++--- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/tasks.py b/tasks.py index eef84af..fb85576 100755 --- a/tasks.py +++ b/tasks.py @@ -79,7 +79,7 @@ class ProblemByheart(Problem): def generate_sequences(self, nb): sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] - ar_mask = (sequences==10).long() + ar_mask = (sequences == 10).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask @@ -87,16 +87,17 @@ class ProblemByheart(Problem): # nb_common_codes = 100 # def generate_sequences(nb_samples): - # problem_indexes = torch.randint(len(problems), (nb_samples,)) - # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - # print(f"{nb_samples_per_problem}") - # all_seq = [] - # for nb, p in zip(nb_samples_per_problem, problems): - # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - # return all_seq + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq # for strain, stest in zip(train_seq, test_seq): - # s = torch.cat((strain, stest), 0) + # s = torch.cat((strain, stest), 0) + class SandBox(Task): def __init__( @@ -107,17 +108,29 @@ class SandBox(Task): batch_size, logger=None, device=torch.device("cpu"), + max_nb_codes=1024, ): super().__init__() self.batch_size = batch_size self.device = device - self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.train_input, self.train_ar_mask = problem.generate_sequences( + nb_train_samples + ) self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + # A bit of paranoia never hurts + assert ( + self.nb_codes <= max_nb_codes + and self.train_input.min() >= 0 + and self.test_input.min() >= 0 + and tuple(self.train_ar_mask.unique()) == (0, 1) + and tuple(self.test_ar_mask.unique()) == (0, 1) + ) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -136,9 +149,8 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - def compute_accuracy(input, ar_mask): - result = input.clone() * (1-ar_mask) + result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( model, self.batch_size, @@ -150,22 +162,27 @@ class SandBox(Task): ) nb_total = ar_mask.sum().item() - nb_correct = ((result==input).long() * ar_mask).sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + train_nb_total, train_nb_correct = compute_accuracy( + self.train_input, self.train_ar_mask + ) logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + test_nb_total, test_nb_correct = compute_accuracy( + self.test_input, self.test_ar_mask + ) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + ###################################################################### import picoclvr diff --git a/world.py b/world.py index 12c6553..1d64fa3 100755 --- a/world.py +++ b/world.py @@ -61,12 +61,13 @@ class SignSTE(nn.Module): else: return s + class DiscreteSampler2d(nn.Module): def __init__(self): super().__init__() def forward(self, x): - s = (x >= x.max(-3,keepdim=True).values).float() + s = (x >= x.max(-3, keepdim=True).values).float() if self.training: u = x.softmax(dim=-3) @@ -96,7 +97,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - mu, std = train_input.float().mean(), train_input.float().std() def encoder_core(depth, dim): @@ -459,7 +459,8 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - 25000, 1000, + 25000, + 1000, nb_epochs=5, mode="first_last", nb_steps=20, -- 2.20.1 From ead4b8e4edd29578c01501d168e416b47fa4047b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 17:26:17 +0200 Subject: [PATCH 06/16] Update. --- main.py | 3 +- tasks.py | 156 +++++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 132 insertions(+), 27 deletions(-) diff --git a/main.py b/main.py index 3be3d55..213524e 100755 --- a/main.py +++ b/main.py @@ -266,7 +266,8 @@ picoclvr_pruner_eval = ( if args.task == "sandbox": task = tasks.SandBox( - tasks.ProblemByheart(), + tasks.ProblemLevel1(), + # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/tasks.py b/tasks.py index fb85576..332d6c5 100755 --- a/tasks.py +++ b/tasks.py @@ -67,13 +67,15 @@ class Problem: def generate_sequences(self, nb): pass - def log_performance(self, sequences, logger): - pass + def seq2str(self, seq): + return "[NOT IMPLEMENTED]" + + +#################### -class ProblemByheart(Problem): - def __init__(self): - nb_seq, len_prompt, len_result = 100, 5, 5 +class ProblemLevel0(Problem): + def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 @@ -83,20 +85,104 @@ class ProblemByheart(Problem): ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask - # problems = [ProblemByheart()] - # nb_common_codes = 100 - # def generate_sequences(nb_samples): - # problem_indexes = torch.randint(len(problems), (nb_samples,)) - # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - # print(f"{nb_samples_per_problem}") - # all_seq = [] - # for nb, p in zip(nb_samples_per_problem, problems): - # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - # return all_seq +class ProblemLevel1(Problem): + def __init__(self, nb_operators=100, len_prompt=5, len_result=8): + self.len_prompt = len_prompt + self.len_result = len_result + self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1 + self.operators = F.one_hot( + torch.rand(nb_operators, len_result, len_prompt).argmax(-1), + num_classes=len_prompt, + ) + + def generate_sequences(self, nb): + a = self.len_nb_operator + b = a + 1 + self.len_prompt + sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) + nb_operators = torch.randint(self.operators.size(0), (nb,)) + sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10 + sequences[:, a] = 10 + sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) + sequences[:, b] = 11 + + o = self.operators[nb_operators] + p = sequences[:, a + 1 : b] + print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}") + sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1) + ar_mask = (sequences == 11).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.id2char[x.item()] for x in seq) - # for strain, stest in zip(train_seq, test_seq): - # s = torch.cat((strain, stest), 0) + +#################### + + +class ProblemAddition(Problem): + def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False): + self.nb_digits = nb_digits + self.zero_padded = zero_padded + self.inverted_result = inverted_result + self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")]) + self.id2char = dict([(n, c) for c, n in self.char2id.items()]) + + def tensorize(self, strings): + len_max = max([len(x) for x in strings]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "$" * (len_max - len(s))] + for s in strings + ] + ) + ], + 0, + ) + + def generate_sequences(self, nb): + sequences = [] + for k in range(nb): + a, b = torch.randint(10**self.nb_digits, (2,)) + c = a + b + a, b, c = str(a.item()), str(b.item()), str(c.item()) + if self.zero_padded: + a = "0" * (self.nb_digits - len(a)) + a + b = "0" * (self.nb_digits - len(b)) + b + c = "0" * (self.nb_digits + 1 - len(c)) + c + if self.inverted_result: + c = c[::-1] + sequences.append(f"{a}+{b}={c}$") + + sequences = self.tensorize(sequences) + ar_mask = (sequences == self.char2id["="]).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join(self.id2char[x.item()] for x in seq) + + +# class ProblemUnion(Problem): +# problems = [ProblemByheart()] +# nb_common_codes = 100 + +# def generate_sequences(nb_samples): +# problem_indexes = torch.randint(len(problems), (nb_samples,)) +# nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) +# print(f"{nb_samples_per_problem}") +# all_seq = [] +# for nb, p in zip(nb_samples_per_problem, problems): +# all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) +# return all_seq + +# for strain, stest in zip(train_seq, test_seq): +# s = torch.cat((strain, stest), 0) + +#################### class SandBox(Task): @@ -114,11 +200,21 @@ class SandBox(Task): self.batch_size = batch_size self.device = device + self.problem = problem - self.train_input, self.train_ar_mask = problem.generate_sequences( + self.train_input, self.train_ar_mask = self.problem.generate_sequences( nb_train_samples ) - self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) + self.test_input, self.test_ar_mask = self.problem.generate_sequences( + nb_test_samples + ) + + self.train_input, self.train_ar_mask = self.train_input.to( + device + ), self.train_ar_mask.to(device) + self.test_input, self.test_ar_mask = self.test_input.to( + device + ), self.test_ar_mask.to(device) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -147,10 +243,12 @@ class SandBox(Task): return self.nb_codes def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis + self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 ): - def compute_accuracy(input, ar_mask): + def compute_accuracy(input, ar_mask, logger=None): + input, ar_mask = input[:nmax], ar_mask[:nmax] result = input.clone() * (1 - ar_mask) + masked_inplace_autoregression( model, self.batch_size, @@ -161,6 +259,15 @@ class SandBox(Task): device=self.device, ) + if logger is not None: + for sp, st in zip(result[:10], input[:10]): + logger( + f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}" + ) + logger( + f" {n_epoch} ground truth {self.problem.seq2str(st)}" + ) + nb_total = ar_mask.sum().item() nb_correct = ((result == input).long() * ar_mask).sum().item() @@ -175,7 +282,7 @@ class SandBox(Task): ) test_nb_total, test_nb_correct = compute_accuracy( - self.test_input, self.test_ar_mask + self.test_input, self.test_ar_mask, logger ) logger( @@ -1119,8 +1226,6 @@ class World(Task): device_storage=device_storage, ) - print(f"{train_action_seq.size()=}") - train_frame_seq = self.frame2seq(train_frames).to(device_storage) test_frame_seq = self.frame2seq(test_frames).to(device_storage) @@ -1132,7 +1237,7 @@ class World(Task): self.nb_codes = nb_frame_codes + nb_action_codes train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1) - print(f"{train_action_seq.device=} {nb_frame_codes.device=}") + train_action_seq += nb_frame_codes self.train_input = torch.cat( (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1 @@ -1191,7 +1296,6 @@ class World(Task): (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1 ) result = result.reshape(-1, result.size(-1)) - print(f"{result.size()=}") frames = self.seq2frame(result) image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png") -- 2.20.1 From 1be1638f9906a1071dc82ebc6f35f8fc0eb91a3d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:03:24 +0200 Subject: [PATCH 07/16] Update. --- tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 332d6c5..5ac78cb 100755 --- a/tasks.py +++ b/tasks.py @@ -101,7 +101,7 @@ class ProblemLevel1(Problem): b = a + 1 + self.len_prompt sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) nb_operators = torch.randint(self.operators.size(0), (nb,)) - sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10 + sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10 sequences[:, a] = 10 sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) sequences[:, b] = 11 @@ -115,7 +115,7 @@ class ProblemLevel1(Problem): return sequences, ar_mask def seq2str(self, seq): - return "".join(self.id2char[x.item()] for x in seq) + return "".join("0123456789|>"[x.item()] for x in seq) #################### -- 2.20.1 From 8d2ebe29b48e3cf2f0a3937ab1e44d0e12a4924e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:17:30 +0200 Subject: [PATCH 08/16] Update. --- tasks.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tasks.py b/tasks.py index 5ac78cb..706e1d9 100755 --- a/tasks.py +++ b/tasks.py @@ -87,29 +87,28 @@ class ProblemLevel0(Problem): class ProblemLevel1(Problem): - def __init__(self, nb_operators=100, len_prompt=5, len_result=8): - self.len_prompt = len_prompt + def __init__(self, nb_operators=100, len_source=5, len_result=8): + self.len_source = len_source self.len_result = len_result self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1 self.operators = F.one_hot( - torch.rand(nb_operators, len_result, len_prompt).argmax(-1), - num_classes=len_prompt, + torch.rand(nb_operators, len_result, len_source).argmax(-1), + num_classes=len_source, ) + + def generate_sequences(self, nb): - a = self.len_nb_operator - b = a + 1 + self.len_prompt - sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) nb_operators = torch.randint(self.operators.size(0), (nb,)) - sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10 - sequences[:, a] = 10 - sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) - sequences[:, b] = 11 - - o = self.operators[nb_operators] - p = sequences[:, a + 1 : b] - print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}") - sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1) + operators = self.operators[nb_operators] + nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10 + marker1 = torch.full((nb,1),10) + source = torch.randint(10, (nb, self.len_source)) + marker2 = torch.full((nb,1),11) + result = operators.bmm(source[:, :, None]).squeeze(-1) + print(f"{nb_operators.dtype=} {marker1.dtype=}") + sequences = torch.cat((nb_operators, marker1, source,marker2,result),1) + print(f"{sequences.size()=}") ar_mask = (sequences == 11).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask -- 2.20.1 From a3211f96c7426a613b82a2de87d4dd70640e8f46 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:23:03 +0200 Subject: [PATCH 09/16] Update. --- main.py | 2 +- tasks.py | 42 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 213524e..e3fd9f0 100755 --- a/main.py +++ b/main.py @@ -266,7 +266,7 @@ picoclvr_pruner_eval = ( if args.task == "sandbox": task = tasks.SandBox( - tasks.ProblemLevel1(), + tasks.ProblemLevel2(), # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index 706e1d9..73f61bf 100755 --- a/tasks.py +++ b/tasks.py @@ -96,18 +96,19 @@ class ProblemLevel1(Problem): num_classes=len_source, ) - - def generate_sequences(self, nb): nb_operators = torch.randint(self.operators.size(0), (nb,)) operators = self.operators[nb_operators] - nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10 - marker1 = torch.full((nb,1),10) + nb_operators = ( + nb_operators[:, None] + // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) + ) % 10 + marker1 = torch.full((nb, 1), 10) source = torch.randint(10, (nb, self.len_source)) - marker2 = torch.full((nb,1),11) + marker2 = torch.full((nb, 1), 11) result = operators.bmm(source[:, :, None]).squeeze(-1) print(f"{nb_operators.dtype=} {marker1.dtype=}") - sequences = torch.cat((nb_operators, marker1, source,marker2,result),1) + sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1) print(f"{sequences.size()=}") ar_mask = (sequences == 11).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) @@ -117,6 +118,35 @@ class ProblemLevel1(Problem): return "".join("0123456789|>"[x.item()] for x in seq) +class ProblemLevel2(Problem): + def __init__(self, len_source=5, len_result=8): + self.len_source = len_source + self.len_result = len_result + + def generate_sequences(self, nb): + operators = F.one_hot( + torch.rand(nb, self.len_result, self.len_source).argmax(-1), + num_classes=self.len_source, + ) + source1 = torch.randint(10, (nb, self.len_source)) + marker1 = torch.full((nb, 1), 10) + result1 = operators.bmm(source1[:, :, None]).squeeze(-1) + marker2 = torch.full((nb, 1), 11) + source2 = torch.randint(10, (nb, self.len_source)) + marker3 = torch.full((nb, 1), 12) + result2 = operators.bmm(source2[:, :, None]).squeeze(-1) + + sequences = torch.cat( + (source1, marker1, result1, marker2, source2, marker3, result2), 1 + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789>|~"[x.item()] for x in seq) + + #################### -- 2.20.1 From d6f73f1d5093fb098e822e14db382dd3a1c63a2a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:35:59 +0200 Subject: [PATCH 10/16] Update. --- main.py | 33 ++++++++++++++++++++++++++++++++- tasks.py | 2 +- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index e3fd9f0..0d4930d 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,17 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options +parser.add_argument("--sandbox_level", type=int, default=0) + +parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) + +parser.add_argument("--sandbox_levels_len_source", type=int, default=5) + +parser.add_argument("--sandbox_levels_len_result", type=int, default=8) + +############################## +# picoclvr options + parser.add_argument("--picoclvr_nb_colors", type=int, default=5) parser.add_argument("--picoclvr_height", type=int, default=12) @@ -265,8 +276,28 @@ picoclvr_pruner_eval = ( ###################################################################### if args.task == "sandbox": + if args.sandbox_level == 0: + problem = tasks.ProblemLevel0( + nb_sentences=args.sandbox_levels_nb_items, + len_prompt=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 1: + problem = tasks.ProblemLevel1( + nb_operators=args.sandbox_levels_nb_items, + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 2: + problem = tasks.ProblemLevel2( + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + else: + raise ValueError(f"Unknown sandbox level {args.sandbox_level}") + task = tasks.SandBox( - tasks.ProblemLevel2(), + problem, # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index 73f61bf..e7c2f75 100755 --- a/tasks.py +++ b/tasks.py @@ -76,7 +76,7 @@ class Problem: class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): - self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): -- 2.20.1 From e781d77071fa26f393f50451f91c70f4a0850ca5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 00:59:24 +0200 Subject: [PATCH 11/16] Update. --- main.py | 6 +++--- tasks.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 0d4930d..19f918c 100755 --- a/main.py +++ b/main.py @@ -86,7 +86,7 @@ parser.add_argument("--sandbox_level", type=int, default=0) parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) -parser.add_argument("--sandbox_levels_len_source", type=int, default=5) +parser.add_argument("--sandbox_levels_len_source", type=int, default=6) parser.add_argument("--sandbox_levels_len_result", type=int, default=8) @@ -163,9 +163,9 @@ if args.result_dir is None: default_args = { "sandbox": { - "nb_epochs": 10, + "nb_epochs": 50, "batch_size": 25, - "nb_train_samples": 25000, + "nb_train_samples": 100000, "nb_test_samples": 10000, }, "picoclvr": { diff --git a/tasks.py b/tasks.py index e7c2f75..c5418b4 100755 --- a/tasks.py +++ b/tasks.py @@ -104,7 +104,8 @@ class ProblemLevel1(Problem): // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) ) % 10 marker1 = torch.full((nb, 1), 10) - source = torch.randint(10, (nb, self.len_source)) + # source = torch.randint(10, (nb, self.len_source)) + source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] marker2 = torch.full((nb, 1), 11) result = operators.bmm(source[:, :, None]).squeeze(-1) print(f"{nb_operators.dtype=} {marker1.dtype=}") @@ -128,7 +129,8 @@ class ProblemLevel2(Problem): torch.rand(nb, self.len_result, self.len_source).argmax(-1), num_classes=self.len_source, ) - source1 = torch.randint(10, (nb, self.len_source)) + source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + # source1 = torch.randint(10, (nb, self.len_source)) marker1 = torch.full((nb, 1), 10) result1 = operators.bmm(source1[:, :, None]).squeeze(-1) marker2 = torch.full((nb, 1), 11) -- 2.20.1 From d6ce421535abe92b66811d3a91f8ba53cd8632a1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 07:24:41 +0200 Subject: [PATCH 12/16] Update. --- main.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 19f918c..efcc0dd 100755 --- a/main.py +++ b/main.py @@ -59,15 +59,17 @@ parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6") -parser.add_argument("--dim_model", type=int, default=512) +parser.add_argument("--model", type=str, default="37M") -parser.add_argument("--dim_keys", type=int, default=64) +parser.add_argument("--dim_model", type=int, default=None) -parser.add_argument("--dim_hidden", type=int, default=2048) +parser.add_argument("--dim_keys", type=int, default=None) -parser.add_argument("--nb_heads", type=int, default=8) +parser.add_argument("--dim_hidden", type=int, default=None) -parser.add_argument("--nb_blocks", type=int, default=12) +parser.add_argument("--nb_heads", type=int, default=None) + +parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) @@ -161,7 +163,7 @@ if args.result_dir is None: ###################################################################### -default_args = { +default_task_args = { "sandbox": { "nb_epochs": 50, "batch_size": 25, @@ -212,13 +214,53 @@ default_args = { }, } -if args.task in default_args: - for k, v in default_args[args.task].items(): +if args.task in default_task_args: + for k, v in default_task_args[args.task].items(): if getattr(args, k) is None: setattr(args, k, v) ###################################################################### +default_model_args = { + "17K": { + "dim_model": 32, + "dim_keys": 32, + "dim_hidden": 32, + "nb_heads": 2, + "nb_blocks": 2, + }, + "37M": { + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 12, + }, + "122M": { + "dim_model": 768, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 24, + }, + "352M": { + "dim_model": 1024, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 48, + }, +} + +if args.model in default_model_args: + for k, v in default_model_args[args.model].items(): + if getattr(args, k) is None: + setattr(args, k, v) +else: + raise ValueError(f"Unknown model {args.model}") + +###################################################################### + try: os.mkdir(args.result_dir) except FileExistsError: -- 2.20.1 From c9dbc3abf436df8af1379d04ab51159e821496f1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 13:54:59 +0200 Subject: [PATCH 13/16] Update. --- main.py | 16 ++++++- rpl.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ tasks.py | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 258 insertions(+), 1 deletion(-) create mode 100755 rpl.py diff --git a/main.py b/main.py index efcc0dd..63e6668 100755 --- a/main.py +++ b/main.py @@ -36,7 +36,7 @@ parser.add_argument( "--task", type=str, default="sandbox", - help="sandbox, picoclvr, mnist, maze, snake, stack, expr, world", + help="sandbox, picoclvr, mnist, maze, snake, stack, expr, rpl, world", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -206,6 +206,12 @@ default_task_args = { "nb_train_samples": 1000000, "nb_test_samples": 10000, }, + "rpl": { + "nb_epochs": 40, + "batch_size": 25, + "nb_train_samples": 1000000, + "nb_test_samples": 10000, + }, "world": { "nb_epochs": 10, "batch_size": 25, @@ -419,6 +425,14 @@ elif args.task == "expr": device=device, ) +elif args.task == "rpl": + task = tasks.RPL( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + device=device, + ) + elif args.task == "world": task = tasks.World( nb_train_samples=args.nb_train_samples, diff --git a/rpl.py b/rpl.py new file mode 100755 index 0000000..42db38c --- /dev/null +++ b/rpl.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +def rpl_exec(program, stack): + for op in program: + if op == "add": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a + b) + elif op == "min": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(min(a, b)) + elif op == "max": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(max(a, b)) + elif op == "swp": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a) + stack.append(b) + elif op == "rep": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack += [b] * a + elif op == "dup": + if len(stack) > 0: + a = stack.pop() + stack.append(a) + stack.append(a) + elif op == "del": + if len(stack) > 0: + a = stack.pop() + else: + raise ValueError(f"Unknown instruction {op}") + + +rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] + +###################################################################### + + +def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): + prog_len = 1 + torch.randint(prog_len - 1, (1,)).item() + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] + result = result + [""] + stack + rpl_exec(prog, stack) + result = result + [""] + stack + + result = result + [""] + prog + result = result + [""] + return result + + +def next_marker(seq, tokens, start=0): + pos = None + for t in tokens: + try: + i = seq.index(t, start) + if pos is None or i < pos: + pos = i + except ValueError: + pass + return pos + + +def check(seq): + io = [] + k = 0 + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) + e = next_marker(seq, ["", ""], start=o) + if o is None or e is None: + raise ValueError("Invalid input/output") + io.append((seq[k + 1 : o], seq[o + 1 : e])) + k = e + + if seq[k] == "": + e = next_marker(seq, [""], start=k) + if e is None: + prog = [] + else: + prog = seq[k + 1 : e] + + nb_total, nb_errors = 0, 0 + + if len(set(prog) - set(rpl_ops)) > 0: + for stack, target_stack in io: + nb_total += len(target_stack) + nb_errors += len(target_stack) + + else: + for stack, target_stack in io: + # print(f"INIT {stack} PROG {prog}") + rpl_exec(prog, stack) + # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + nb_total += len(target_stack) + nb_errors += abs(len(stack) - len(target_stack)) + nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + + return nb_total, nb_errors + + +###################################################################### + +if __name__ == "__main__": + seq = generate() + print(seq) + seq[3] = 7 + print(seq) + print(check(seq)) diff --git a/tasks.py b/tasks.py index c5418b4..a3d47f5 100755 --- a/tasks.py +++ b/tasks.py @@ -1021,6 +1021,124 @@ class Stack(Task): ############################################################## +###################################################################### + +import rpl + + +class RPL(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [ + self.token2id[str(c)] + for c in s + [""] * (len_max - len(s)) + ] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + + train_sequences = [ + rpl.generate() + for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data") + ] + test_sequences = [ + rpl.generate() for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") + ] + + symbols = list( + set([""] + [x for l in train_sequences + test_sequences for x in l]) + ) + val_max = max([x if type(x) is int else 0 for x in symbols]) + symbols = list(filter(lambda x: type(x) is str, symbols)) + symbols.sort() + symbols += [str(n) for n in range(val_max + 1)] + print(f"{val_max=}") + self.token2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2token = dict([(n, c) for c, n in self.token2id.items()]) + + self.t_nul, self.t_prog = self.token2id[""], self.token2id[""] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.t_nul).max(0).values.nonzero().max() + 3 + batch = batch[:, :last] + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_errors(input, nb_to_log=0): + result = input.clone() + s = (result == self.t_prog).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + if nb_to_log > 0: + for x in result[:nb_to_log]: + s = " ".join([self.id2token[i.item()] for i in x]) + logger(f"check {n_epoch} {s}") + nb_to_log -= min(nb_to_log, result.size(0)) + + sum_nb_total, sum_nb_errors = 0, 0 + for x in result: + seq = [self.id2token[i.item()] for i in x] + nb_total, nb_errors = rpl.check(seq) + sum_nb_total += nb_total + sum_nb_errors += nb_errors + + return sum_nb_total, sum_nb_errors + + test_nb_total, test_nb_errors = compute_nb_errors(self.test_input, nb_to_log=10) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + ###################################################################### -- 2.20.1 From 0c47d4d8ef8c4938f4765af816349cf30da14cb1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 15:31:28 +0200 Subject: [PATCH 14/16] Update. --- rpl.py | 47 +++++++++++++++++++++++++++++++++-------------- tasks.py | 38 ++++++++++++++++++++++++++------------ 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/rpl.py b/rpl.py index 42db38c..155bc69 100755 --- a/rpl.py +++ b/rpl.py @@ -11,6 +11,7 @@ from torch.nn import functional as F def rpl_exec(program, stack): + stack = stack.copy() for op in program: if op == "add": if len(stack) > 1: @@ -44,6 +45,8 @@ def rpl_exec(program, stack): else: raise ValueError(f"Unknown instruction {op}") + return stack + rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] @@ -57,9 +60,8 @@ def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): result = [] for _ in range(nb_runs): stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] - result = result + [""] + stack - rpl_exec(prog, stack) - result = result + [""] + stack + result_stack = rpl_exec(prog, stack) + result = result + [""] + stack + [""] + result_stack result = result + [""] + prog result = result + [""] @@ -78,7 +80,7 @@ def next_marker(seq, tokens, start=0): return pos -def check(seq): +def decompose(seq): io = [] k = 0 while seq[k] == "": @@ -86,7 +88,13 @@ def check(seq): e = next_marker(seq, ["", ""], start=o) if o is None or e is None: raise ValueError("Invalid input/output") - io.append((seq[k + 1 : o], seq[o + 1 : e])) + try: + io.append( + ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) + ) + except ValueError: + raise ValueError("Invalid input/output") + k = e if seq[k] == "": @@ -95,24 +103,35 @@ def check(seq): prog = [] else: prog = seq[k + 1 : e] + return prog, io + + +def compute_nb_errors(seq): + prog, io = decompose(seq) nb_total, nb_errors = 0, 0 + stacks = [] + if len(set(prog) - set(rpl_ops)) > 0: - for stack, target_stack in io: + # Program is not valid, we count 100% error + for start_stack, target_stack in io: + stacks.append((start_stack, target_stack, "N/A", False)) nb_total += len(target_stack) nb_errors += len(target_stack) else: - for stack, target_stack in io: - # print(f"INIT {stack} PROG {prog}") - rpl_exec(prog, stack) - # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + # Program is valid + for start_stack, target_stack in io: + result_stack = rpl_exec(prog, start_stack) nb_total += len(target_stack) - nb_errors += abs(len(stack) - len(target_stack)) - nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + e = abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) - return nb_total, nb_errors + return nb_total, nb_errors, prog, stacks ###################################################################### @@ -122,4 +141,4 @@ if __name__ == "__main__": print(seq) seq[3] = 7 print(seq) - print(check(seq)) + print(compute_nb_errors(seq)) diff --git a/tasks.py b/tasks.py index a3d47f5..75cd35e 100755 --- a/tasks.py +++ b/tasks.py @@ -1044,6 +1044,9 @@ class RPL(Task): 0, ).to(self.device) + def seq2str(self, seq): + return " ".join([self.id2token[i] for i in seq]) + def __init__( self, nb_train_samples, @@ -1117,22 +1120,33 @@ class RPL(Task): device=self.device, ) - if nb_to_log > 0: - for x in result[:nb_to_log]: - s = " ".join([self.id2token[i.item()] for i in x]) - logger(f"check {n_epoch} {s}") - nb_to_log -= min(nb_to_log, result.size(0)) - sum_nb_total, sum_nb_errors = 0, 0 - for x in result: - seq = [self.id2token[i.item()] for i in x] - nb_total, nb_errors = rpl.check(seq) - sum_nb_total += nb_total - sum_nb_errors += nb_errors + for x, y in zip(input, result): + seq = [self.id2token[i.item()] for i in y] + nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) + sum_nb_total += 1 + sum_nb_errors += 0 if nb_errors == 0 else 1 + if nb_to_log > 0: + gt_seq = [self.id2token[i.item()] for i in x] + _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) + gt_prog = " ".join([str(x) for x in gt_prog]) + prog = " ".join([str(x) for x in prog]) + logger(f"GROUND-TRUTH PROG [{gt_prog}] PREDICTED PROG [{prog}]") + for start_stack, target_stack, result_stack, correct in stacks: + comment = " CORRECT" if correct else "" + start_stack = " ".join([str(x) for x in start_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + result_stack = " ".join([str(x) for x in result_stack]) + logger( + f" [{start_stack}] -> [{result_stack}] TARGET [{target_stack}]{comment}" + ) + nb_to_log -= 1 return sum_nb_total, sum_nb_errors - test_nb_total, test_nb_errors = compute_nb_errors(self.test_input, nb_to_log=10) + test_nb_total, test_nb_errors = compute_nb_errors( + self.test_input[:1000], nb_to_log=10 + ) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" -- 2.20.1 From 5703df4c32a0856c8fa4b1ff97810cdc1fb76253 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 15:43:01 +0200 Subject: [PATCH 15/16] Update. --- main.py | 2 +- rpl.py | 6 +++--- tasks.py | 24 ++++++++++++++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 63e6668..d1f82cf 100755 --- a/main.py +++ b/main.py @@ -209,7 +209,7 @@ default_task_args = { "rpl": { "nb_epochs": 40, "batch_size": 25, - "nb_train_samples": 1000000, + "nb_train_samples": 100000, "nb_test_samples": 10000, }, "world": { diff --git a/rpl.py b/rpl.py index 155bc69..7f7dcfc 100755 --- a/rpl.py +++ b/rpl.py @@ -53,13 +53,13 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): - prog_len = 1 + torch.randint(prog_len - 1, (1,)).item() +def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): + prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] result = [] for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] + stack = [x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))] result_stack = rpl_exec(prog, stack) result = result + [""] + stack + [""] + result_stack diff --git a/tasks.py b/tasks.py index 75cd35e..e14ceb7 100755 --- a/tasks.py +++ b/tasks.py @@ -1052,6 +1052,10 @@ class RPL(Task): nb_train_samples, nb_test_samples, batch_size, + nb_starting_values=3, + max_input=9, + prog_len=6, + nb_runs=5, device=torch.device("cpu"), ): super().__init__() @@ -1060,11 +1064,23 @@ class RPL(Task): self.device = device train_sequences = [ - rpl.generate() + rpl.generate( + nb_starting_values=nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data") ] + test_sequences = [ - rpl.generate() for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") + rpl.generate( + nb_starting_values=nb_starting_values, + max_input=max_input, + prog_len=prog_len, + nb_runs=nb_runs, + ) + for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") ] symbols = list( @@ -1131,14 +1147,14 @@ class RPL(Task): _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) gt_prog = " ".join([str(x) for x in gt_prog]) prog = " ".join([str(x) for x in prog]) - logger(f"GROUND-TRUTH PROG [{gt_prog}] PREDICTED PROG [{prog}]") + logger(f"PROG [{gt_prog}] PREDICTED [{prog}]") for start_stack, target_stack, result_stack, correct in stacks: comment = " CORRECT" if correct else "" start_stack = " ".join([str(x) for x in start_stack]) target_stack = " ".join([str(x) for x in target_stack]) result_stack = " ".join([str(x) for x in result_stack]) logger( - f" [{start_stack}] -> [{result_stack}] TARGET [{target_stack}]{comment}" + f" [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]{comment}" ) nb_to_log -= 1 -- 2.20.1 From 439c597d409c344283f8996f042daf79d3f24de2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 16:14:50 +0200 Subject: [PATCH 16/16] Update. --- main.py | 1 + rpl.py | 28 +++++++++++++++++++--------- tasks.py | 15 ++++++++++++--- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index d1f82cf..901b1d0 100755 --- a/main.py +++ b/main.py @@ -430,6 +430,7 @@ elif args.task == "rpl": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + logger=log_string, device=device, ) diff --git a/rpl.py b/rpl.py index 7f7dcfc..7e865a5 100755 --- a/rpl.py +++ b/rpl.py @@ -55,16 +55,26 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() - prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] - result = [] - for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))] - result_stack = rpl_exec(prog, stack) - result = result + [""] + stack + [""] + result_stack + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack + + result = result + [""] + prog + result = result + [""] + if no_empty_stack: + break - result = result + [""] + prog - result = result + [""] return result @@ -116,7 +126,7 @@ def compute_nb_errors(seq): if len(set(prog) - set(rpl_ops)) > 0: # Program is not valid, we count 100% error for start_stack, target_stack in io: - stacks.append((start_stack, target_stack, "N/A", False)) + stacks.append((start_stack, target_stack, ["N/A"], False)) nb_total += len(target_stack) nb_errors += len(target_stack) diff --git a/tasks.py b/tasks.py index e14ceb7..0f44760 100755 --- a/tasks.py +++ b/tasks.py @@ -1056,6 +1056,7 @@ class RPL(Task): max_input=9, prog_len=6, nb_runs=5, + logger=None, device=torch.device("cpu"), ): super().__init__() @@ -1099,6 +1100,13 @@ class RPL(Task): self.train_input = self.tensorize(train_sequences) self.test_input = self.tensorize(test_sequences) + if logger is not None: + for x in self.train_input[:10]: + end = (x != self.t_nul).nonzero().max().item() + 1 + seq = [self.id2token[i.item()] for i in x[:end]] + s = " ".join(seq) + logger(f"example_seq {s}") + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): @@ -1147,14 +1155,15 @@ class RPL(Task): _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) gt_prog = " ".join([str(x) for x in gt_prog]) prog = " ".join([str(x) for x in prog]) - logger(f"PROG [{gt_prog}] PREDICTED [{prog}]") + comment = "*" if nb_errors == 0 else "-" + logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]") for start_stack, target_stack, result_stack, correct in stacks: - comment = " CORRECT" if correct else "" + comment = "*" if correct else "-" start_stack = " ".join([str(x) for x in start_stack]) target_stack = " ".join([str(x) for x in target_stack]) result_stack = " ".join([str(x) for x in result_stack]) logger( - f" [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]{comment}" + f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]" ) nb_to_log -= 1 -- 2.20.1