- stack[k, st, stack_pointers[k, st]] = val_push
- depth_counts[:, 2 * t + 1] = stack_pointers[k, st]
- stack_pointers[k[op == 0], st[op == 0]] += 1
- stack_pointers[k[op == 1], st[op == 1]] -= 1
- result[:, 2 * t] = st * 2 + op
- result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
+ stack[k, st, stack_counts[k, st]] = val_push
+ recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+ stack_counts[k[op == 0], st[op == 0]] += 1
+ stack_counts[k[op == 1], st[op == 1]] -= 1
+ result[:, (1 + nb_digits) * t] = st * 2 + op
+ for d in range(nb_digits):
+ result[:, (1 + nb_digits) * t + 1 + d] = (
+ (op * val_pop + (1 - op) * val_push) // (10**d)
+ ) % 10 + 2 * nb_stacks
+
+ return result.to(device), recorded_stack_counts.to(device)