From 0c47d4d8ef8c4938f4765af816349cf30da14cb1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 15:31:28 +0200 Subject: [PATCH] Update. --- rpl.py | 47 +++++++++++++++++++++++++++++++++-------------- tasks.py | 38 ++++++++++++++++++++++++++------------ 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/rpl.py b/rpl.py index 42db38c..155bc69 100755 --- a/rpl.py +++ b/rpl.py @@ -11,6 +11,7 @@ from torch.nn import functional as F def rpl_exec(program, stack): + stack = stack.copy() for op in program: if op == "add": if len(stack) > 1: @@ -44,6 +45,8 @@ def rpl_exec(program, stack): else: raise ValueError(f"Unknown instruction {op}") + return stack + rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] @@ -57,9 +60,8 @@ def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): result = [] for _ in range(nb_runs): stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] - result = result + [""] + stack - rpl_exec(prog, stack) - result = result + [""] + stack + result_stack = rpl_exec(prog, stack) + result = result + [""] + stack + [""] + result_stack result = result + [""] + prog result = result + [""] @@ -78,7 +80,7 @@ def next_marker(seq, tokens, start=0): return pos -def check(seq): +def decompose(seq): io = [] k = 0 while seq[k] == "": @@ -86,7 +88,13 @@ def check(seq): e = next_marker(seq, ["", ""], start=o) if o is None or e is None: raise ValueError("Invalid input/output") - io.append((seq[k + 1 : o], seq[o + 1 : e])) + try: + io.append( + ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) + ) + except ValueError: + raise ValueError("Invalid input/output") + k = e if seq[k] == "": @@ -95,24 +103,35 @@ def check(seq): prog = [] else: prog = seq[k + 1 : e] + return prog, io + + +def compute_nb_errors(seq): + prog, io = decompose(seq) nb_total, nb_errors = 0, 0 + stacks = [] + if len(set(prog) - set(rpl_ops)) > 0: - for stack, target_stack in io: + # Program is not valid, we count 100% error + for start_stack, target_stack in io: + stacks.append((start_stack, target_stack, "N/A", False)) nb_total += len(target_stack) nb_errors += len(target_stack) else: - for stack, target_stack in io: - # print(f"INIT {stack} PROG {prog}") - rpl_exec(prog, stack) - # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + # Program is valid + for start_stack, target_stack in io: + result_stack = rpl_exec(prog, start_stack) nb_total += len(target_stack) - nb_errors += abs(len(stack) - len(target_stack)) - nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + e = abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) - return nb_total, nb_errors + return nb_total, nb_errors, prog, stacks ###################################################################### @@ -122,4 +141,4 @@ if __name__ == "__main__": print(seq) seq[3] = 7 print(seq) - print(check(seq)) + print(compute_nb_errors(seq)) diff --git a/tasks.py b/tasks.py index a3d47f5..75cd35e 100755 --- a/tasks.py +++ b/tasks.py @@ -1044,6 +1044,9 @@ class RPL(Task): 0, ).to(self.device) + def seq2str(self, seq): + return " ".join([self.id2token[i] for i in seq]) + def __init__( self, nb_train_samples, @@ -1117,22 +1120,33 @@ class RPL(Task): device=self.device, ) - if nb_to_log > 0: - for x in result[:nb_to_log]: - s = " ".join([self.id2token[i.item()] for i in x]) - logger(f"check {n_epoch} {s}") - nb_to_log -= min(nb_to_log, result.size(0)) - sum_nb_total, sum_nb_errors = 0, 0 - for x in result: - seq = [self.id2token[i.item()] for i in x] - nb_total, nb_errors = rpl.check(seq) - sum_nb_total += nb_total - sum_nb_errors += nb_errors + for x, y in zip(input, result): + seq = [self.id2token[i.item()] for i in y] + nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) + sum_nb_total += 1 + sum_nb_errors += 0 if nb_errors == 0 else 1 + if nb_to_log > 0: + gt_seq = [self.id2token[i.item()] for i in x] + _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) + gt_prog = " ".join([str(x) for x in gt_prog]) + prog = " ".join([str(x) for x in prog]) + logger(f"GROUND-TRUTH PROG [{gt_prog}] PREDICTED PROG [{prog}]") + for start_stack, target_stack, result_stack, correct in stacks: + comment = " CORRECT" if correct else "" + start_stack = " ".join([str(x) for x in start_stack]) + target_stack = " ".join([str(x) for x in target_stack]) + result_stack = " ".join([str(x) for x in result_stack]) + logger( + f" [{start_stack}] -> [{result_stack}] TARGET [{target_stack}]{comment}" + ) + nb_to_log -= 1 return sum_nb_total, sum_nb_errors - test_nb_total, test_nb_errors = compute_nb_errors(self.test_input, nb_to_log=10) + test_nb_total, test_nb_errors = compute_nb_errors( + self.test_input[:1000], nb_to_log=10 + ) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" -- 2.39.5