X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=rpl.py;h=155bc69cc76b4b88a3bd05c75cb2ce977e762df9;hb=0c47d4d8ef8c4938f4765af816349cf30da14cb1;hp=42db38cad4c430f38e8cf38f3872b82e91aa31ff;hpb=c9dbc3abf436df8af1379d04ab51159e821496f1;p=picoclvr.git diff --git a/rpl.py b/rpl.py index 42db38c..155bc69 100755 --- a/rpl.py +++ b/rpl.py @@ -11,6 +11,7 @@ from torch.nn import functional as F def rpl_exec(program, stack): + stack = stack.copy() for op in program: if op == "add": if len(stack) > 1: @@ -44,6 +45,8 @@ def rpl_exec(program, stack): else: raise ValueError(f"Unknown instruction {op}") + return stack + rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] @@ -57,9 +60,8 @@ def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): result = [] for _ in range(nb_runs): stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] - result = result + [""] + stack - rpl_exec(prog, stack) - result = result + [""] + stack + result_stack = rpl_exec(prog, stack) + result = result + [""] + stack + [""] + result_stack result = result + [""] + prog result = result + [""] @@ -78,7 +80,7 @@ def next_marker(seq, tokens, start=0): return pos -def check(seq): +def decompose(seq): io = [] k = 0 while seq[k] == "": @@ -86,7 +88,13 @@ def check(seq): e = next_marker(seq, ["", ""], start=o) if o is None or e is None: raise ValueError("Invalid input/output") - io.append((seq[k + 1 : o], seq[o + 1 : e])) + try: + io.append( + ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) + ) + except ValueError: + raise ValueError("Invalid input/output") + k = e if seq[k] == "": @@ -95,24 +103,35 @@ def check(seq): prog = [] else: prog = seq[k + 1 : e] + return prog, io + + +def compute_nb_errors(seq): + prog, io = decompose(seq) nb_total, nb_errors = 0, 0 + stacks = [] + if len(set(prog) - set(rpl_ops)) > 0: - for stack, target_stack in io: + # Program is not valid, we count 100% error + for start_stack, target_stack in io: + stacks.append((start_stack, target_stack, "N/A", False)) nb_total += len(target_stack) nb_errors += len(target_stack) else: - for stack, target_stack in io: - # print(f"INIT {stack} PROG {prog}") - rpl_exec(prog, stack) - # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + # Program is valid + for start_stack, target_stack in io: + result_stack = rpl_exec(prog, start_stack) nb_total += len(target_stack) - nb_errors += abs(len(stack) - len(target_stack)) - nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + e = abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) - return nb_total, nb_errors + return nb_total, nb_errors, prog, stacks ###################################################################### @@ -122,4 +141,4 @@ if __name__ == "__main__": print(seq) seq[3] = 7 print(seq) - print(check(seq)) + print(compute_nb_errors(seq))