From cd5e4647e105a10012d687169d49bec0343e274f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 20 Jul 2023 14:11:54 +0200 Subject: [PATCH] Update. --- rpl.py | 9 +++++++-- tasks.py | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/rpl.py b/rpl.py index 7c1c96e..8d31efe 100755 --- a/rpl.py +++ b/rpl.py @@ -58,7 +58,9 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] ###################################################################### -def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): +def generate( + nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5 +): prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item() while True: @@ -77,7 +79,10 @@ def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5): result = result + [""] + prog result = result + [""] - if no_empty_stack: + + if no_empty_stack and ( + nb_result_values_max is None or len(result_stack) <= nb_result_values_max + ): break return result diff --git a/tasks.py b/tasks.py index 889d4a9..0827a44 100755 --- a/tasks.py +++ b/tasks.py @@ -1070,6 +1070,7 @@ class RPL(Task): train_sequences = [ rpl.generate( nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, max_input=max_input, prog_len=prog_len, nb_runs=nb_runs, @@ -1080,6 +1081,7 @@ class RPL(Task): test_sequences = [ rpl.generate( nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, max_input=max_input, prog_len=prog_len, nb_runs=nb_runs, -- 2.39.5