From 439c597d409c344283f8996f042daf79d3f24de2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 16:14:50 +0200 Subject: [PATCH] Update. --- main.py | 1 + rpl.py | 28 +++++++++++++++++++--------- tasks.py | 15 ++++++++++++--- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index d1f82cf..901b1d0 100755 --- a/main.py +++ b/main.py @@ -430,6 +430,7 @@ elif args.task == "rpl": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + logger=log_string, device=device, ) diff --git a/rpl.py b/rpl.py index 7f7dcfc..7e865a5 100755 --- a/rpl.py +++ b/rpl.py @@ -55,16 +55,26 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() - prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] - result = [] - for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))] - result_stack = rpl_exec(prog, stack) - result = result + [""] + stack + [""] + result_stack + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack + + result = result + [""] + prog + result = result + [""] + if no_empty_stack: + break - result = result + [""] + prog - result = result + [""] return result @@ -116,7 +126,7 @@ def compute_nb_errors(seq): if len(set(prog) - set(rpl_ops)) > 0: # Program is not valid, we count 100% error for start_stack, target_stack in io: - stacks.append((start_stack, target_stack, "N/A", False)) + stacks.append((start_stack, target_stack, ["N/A"], False)) nb_total += len(target_stack) nb_errors += len(target_stack) diff --git a/tasks.py b/tasks.py index e14ceb7..0f44760 100755 --- a/tasks.py +++ b/tasks.py @@ -1056,6 +1056,7 @@ class RPL(Task): max_input=9, prog_len=6, nb_runs=5, + logger=None, device=torch.device("cpu"), ): super().__init__() @@ -1099,6 +1100,13 @@ class RPL(Task): self.train_input = self.tensorize(train_sequences) self.test_input = self.tensorize(test_sequences) + if logger is not None: + for x in self.train_input[:10]: + end = (x != self.t_nul).nonzero().max().item() + 1 + seq = [self.id2token[i.item()] for i in x[:end]] + s = " ".join(seq) + logger(f"example_seq {s}") + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): @@ -1147,14 +1155,15 @@ class RPL(Task): _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) gt_prog = " ".join([str(x) for x in gt_prog]) prog = " ".join([str(x) for x in prog]) - logger(f"PROG [{gt_prog}] PREDICTED [{prog}]") + comment = "*" if nb_errors == 0 else "-" + logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]") for start_stack, target_stack, result_stack, correct in stacks: - comment = " CORRECT" if correct else "" + comment = "*" if correct else "-" start_stack = " ".join([str(x) for x in start_stack]) target_stack = " ".join([str(x) for x in target_stack]) result_stack = " ".join([str(x) for x in result_stack]) logger( - f" [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]{comment}" + f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]" ) nb_to_log -= 1 -- 2.39.5