X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=rpl.py;h=f826fc4fd5fbe33eed919e8d0f5d80220047b73b;hb=d7eeacf1eab237bbbe67d3e44b90b57fd1445667;hp=7e865a56556110b36e871a75451c72708807e45a;hpb=439c597d409c344283f8996f042daf79d3f24de2;p=picoclvr.git diff --git a/rpl.py b/rpl.py index 7e865a5..f826fc4 100755 --- a/rpl.py +++ b/rpl.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -53,7 +58,9 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): +def generate( + nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5 +): prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() while True: @@ -72,7 +79,10 @@ def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): result = result + [""] + prog result = result + [""] - if no_empty_stack: + + if no_empty_stack and ( + nb_result_values_max is None or len(result_stack) <= nb_result_values_max + ): break return result @@ -97,13 +107,17 @@ def decompose(seq): o = next_marker(seq, [""], start=k + 1) e = next_marker(seq, ["", ""], start=o) if o is None or e is None: - raise ValueError("Invalid input/output") + raise ValueError( + "Missing input/output markers (should be correct in the prompt)" + ) try: io.append( ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]]) ) except ValueError: - raise ValueError("Invalid input/output") + raise ValueError( + "Invalid input/output value (should be correct in the prompt)" + ) k = e @@ -113,6 +127,9 @@ def decompose(seq): prog = [] else: prog = seq[k + 1 : e] + else: + raise ValueError("Missing (it should be in the prompt)") + return prog, io