######################################################################
-def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5):
+def generate(
+ nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
+):
prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
while True:
result = result + ["<prog>"] + prog
result = result + ["<end>"]
- if no_empty_stack:
+
+ if no_empty_stack and (
+ nb_result_values_max is None or len(result_stack) <= nb_result_values_max
+ ):
break
return result
train_sequences = [
rpl.generate(
nb_starting_values=nb_starting_values,
+ nb_result_values_max=4 * nb_starting_values,
max_input=max_input,
prog_len=prog_len,
nb_runs=nb_runs,
test_sequences = [
rpl.generate(
nb_starting_values=nb_starting_values,
+ nb_result_values_max=4 * nb_starting_values,
max_input=max_input,
prog_len=prog_len,
nb_runs=nb_runs,