+ seq, recorded_stack_counts = generate_sequences(
+ nb=3,
+ nb_steps=6,
+ nb_stacks=3,
+ nb_digits=3,
+ )
+
+ sep = torch.full((seq.size(0), 1), seq.max() + 1)
+
+ seq = torch.cat([seq, sep, seq], dim=1)
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ remove_popped_values(seq, 3, 3)
+
+ print()
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ exit(0)
+