X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=rpl.py;h=7e865a56556110b36e871a75451c72708807e45a;hb=439c597d409c344283f8996f042daf79d3f24de2;hp=155bc69cc76b4b88a3bd05c75cb2ce977e762df9;hpb=0c47d4d8ef8c4938f4765af816349cf30da14cb1;p=picoclvr.git diff --git a/rpl.py b/rpl.py index 155bc69..7e865a5 100755 --- a/rpl.py +++ b/rpl.py @@ -53,18 +53,28 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): - prog_len = 1 + torch.randint(prog_len - 1, (1,)).item() - prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] - - result = [] - for _ in range(nb_runs): - stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] - result_stack = rpl_exec(prog, stack) - result = result + [""] + stack + [""] + result_stack - - result = result + [""] + prog - result = result + [""] +def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): + prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() + + while True: + no_empty_stack = True + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [ + x.item() for x in torch.randint(max_input + 1, (nb_starting_values,)) + ] + result_stack = rpl_exec(prog, stack) + if len(result_stack) == 0: + no_empty_stack = False + result = result + [""] + stack + [""] + result_stack + + result = result + [""] + prog + result = result + [""] + if no_empty_stack: + break + return result @@ -116,7 +126,7 @@ def compute_nb_errors(seq): if len(set(prog) - set(rpl_ops)) > 0: # Program is not valid, we count 100% error for start_stack, target_stack in io: - stacks.append((start_stack, target_stack, "N/A", False)) + stacks.append((start_stack, target_stack, ["N/A"], False)) nb_total += len(target_stack) nb_errors += len(target_stack)