# CODE_VAL=val + 2 * nb_stacks
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+ nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
- val_push = torch.randint(10**nb_digits, (nb,))
+ if values is None:
+ val_push = torch.randint(10**nb_digits, (nb,))
+ else:
+ val_push = values[torch.randint(values.size(0), (nb,))]
val_pop = stack[
k,
st,
######################################################################
if __name__ == "__main__":
- nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1
+ nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
seq, recorded_stack_counts = generate_sequences(
nb=nb,
nb_steps=nb_steps,
nb_digits=nb_digits,
)
- print("-- TRAIN -----------------------------")
-
for n in range(min(10, seq.size(0))):
# print(seq_to_str(seq[n], recorded_stack_counts[n]))
print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
- print("-- TEST ------------------------------")
+ print("-- PREPARED FOR TEST -----------------")
remove_popped_values(seq, nb_stacks, nb_digits)