X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=rpl.py;h=b848afa756a6188a60d29a446bdd51d1e455b74a;hb=b1d28a1ed672be21947509dac2f90666b65b5034;hp=7f7dcfc247e72d5e5592e4933e6b9febb77bb61b;hpb=5703df4c32a0856c8fa4b1ff97810cdc1fb76253;p=picoclvr.git diff --git a/rpl.py b/rpl.py index 7f7dcfc..b848afa 100755 --- a/rpl.py +++ b/rpl.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -53,18 +58,33 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): +def generate( + nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5 +): prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() - prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] - result = [] - for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))] - result_stack = rpl_exec(prog, stack) - result = result + [""] + stack + [""] + result_stack + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack + + result = result + [""] + prog + result = result + [""] + + if no_empty_stack and ( + nb_result_values_max is None or len(result_stack) <= nb_result_values_max + ): + break - result = result + [""] + prog - result = result + [""] return result @@ -83,29 +103,44 @@ def next_marker(seq, tokens, start=0): def decompose(seq): io = [] k = 0 - while seq[k] == "": - o = next_marker(seq, [""], start=k + 1) - e = next_marker(seq, ["", ""], start=o) - if o is None or e is None: - raise ValueError("Invalid input/output") + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) + if o is None: + raise ValueError("Missing output markers (should be correct in the prompt)") + e = next_marker(seq, ["", ""], start=o) + if e is None: + raise ValueError( + "Missing input/output markers (should be correct in the prompt)" + ) try: io.append( ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) ) except ValueError: - raise ValueError("Invalid input/output") + raise ValueError( + "Invalid input/output value (should be correct in the prompt)" + ) k = e - if seq[k] == "": + if seq[k] == "": e = next_marker(seq, [""], start=k) if e is None: prog = [] else: prog = seq[k + 1 : e] + else: + raise ValueError("Missing (it should be in the prompt)") + return prog, io +def stack_distance(target_stack, result_stack): + return abs(len(result_stack) - len(target_stack)) + sum( + [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] + ) + + def compute_nb_errors(seq): prog, io = decompose(seq) @@ -116,7 +151,7 @@ def compute_nb_errors(seq): if len(set(prog) - set(rpl_ops)) > 0: # Program is not valid, we count 100% error for start_stack, target_stack in io: - stacks.append((start_stack, target_stack, "N/A", False)) + stacks.append((start_stack, target_stack, ["N/A"], False)) nb_total += len(target_stack) nb_errors += len(target_stack) @@ -125,9 +160,7 @@ def compute_nb_errors(seq): for start_stack, target_stack in io: result_stack = rpl_exec(prog, start_stack) nb_total += len(target_stack) - e = abs(len(result_stack) - len(target_stack)) + sum( - [0 if x == y else 1 for x, y in zip(result_stack, target_stack)] - ) + e = stack_distance(target_stack, result_stack) nb_errors += e stacks.append((start_stack, target_stack, result_stack, e == 0))