for t in range(nb_steps):
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
for t in range(nb_steps):
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)