X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=rpl.py;h=b848afa756a6188a60d29a446bdd51d1e455b74a;hb=8855d37cef610b39f37d2b3b331046d1e7040a37;hp=42db38cad4c430f38e8cf38f3872b82e91aa31ff;hpb=c9dbc3abf436df8af1379d04ab51159e821496f1;p=picoclvr.git diff --git a/rpl.py b/rpl.py index 42db38c..b848afa 100755 --- a/rpl.py +++ b/rpl.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -11,6 +16,7 @@ from torch.nn import functional as F def rpl_exec(program, stack): + stack = stack.copy() for op in program: if op == "add": if len(stack) > 1: @@ -44,25 +50,41 @@ def rpl_exec(program, stack): else: raise ValueError(f"Unknown instruction {op}") + return stack + rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): - prog_len = 1 + torch.randint(prog_len - 1, (1,)).item() - prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] +def generate( + nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5 +): + prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() + + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack - result = [] - for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] - result = result + [""] + stack - rpl_exec(prog, stack) - result = result + [""] + stack + result = result + [""] + prog + result = result + [""] + + if no_empty_stack and ( + nb_result_values_max is None or len(result_stack) <= nb_result_values_max + ): + break - result = result + [""] + prog - result = result + [""] return result @@ -78,41 +100,71 @@ def next_marker(seq, tokens, start=0): return pos -def check(seq): +def decompose(seq): io = [] k = 0 - while seq[k] == "": - o = next_marker(seq, [""], start=k + 1) - e = next_marker(seq, ["", ""], start=o) - if o is None or e is None: - raise ValueError("Invalid input/output") - io.append((seq[k + 1 : o], seq[o + 1 : e])) + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) + if o is None: + raise ValueError("Missing output markers (should be correct in the prompt)") + e = next_marker(seq, ["", ""], start=o) + if e is None: + raise ValueError( + "Missing input/output markers (should be correct in the prompt)" + ) + try: + io.append( + ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) + ) + except ValueError: + raise ValueError( + "Invalid input/output value (should be correct in the prompt)" + ) + k = e - if seq[k] == "": + if seq[k] == "": e = next_marker(seq, [""], start=k) if e is None: prog = [] else: prog = seq[k + 1 : e] + else: + raise ValueError("Missing (it should be in the prompt)") + + return prog, io + + +def stack_distance(target_stack, result_stack): + return abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + + +def compute_nb_errors(seq): + prog, io = decompose(seq) nb_total, nb_errors = 0, 0 + stacks = [] + if len(set(prog) - set(rpl_ops)) > 0: - for stack, target_stack in io: + # Program is not valid, we count 100% error + for start_stack, target_stack in io: + stacks.append((start_stack, target_stack, ["N/A"], False)) nb_total += len(target_stack) nb_errors += len(target_stack) else: - for stack, target_stack in io: - # print(f"INIT {stack} PROG {prog}") - rpl_exec(prog, stack) - # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + # Program is valid + for start_stack, target_stack in io: + result_stack = rpl_exec(prog, start_stack) nb_total += len(target_stack) - nb_errors += abs(len(stack) - len(target_stack)) - nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + e = stack_distance(target_stack, result_stack) + nb_errors += e + stacks.append((start_stack, target_stack, result_stack, e == 0)) - return nb_total, nb_errors + return nb_total, nb_errors, prog, stacks ###################################################################### @@ -122,4 +174,4 @@ if __name__ == "__main__": print(seq) seq[3] = 7 print(seq) - print(check(seq)) + print(compute_nb_errors(seq))