# Mask of positions in the test input holding "pop" tokens: odd codes below
# 2 * nb_stacks (presumably the task encodes pops as odd ids — TODO confirm
# against the encoder).
i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
# Stack depths recorded at those pop positions (assumes test_stack_counts is
# aligned element-wise with test_input — verify against caller).
counts = self.test_stack_counts.flatten()[i.flatten()]
# Histogram of observed depths: one_hot over depth values, summed over samples.
counts = F.one_hot(counts).sum(0)
# NOTE: resolved leftover diff residue here — keep the "test_" prefixed label,
# consistent with the test-split data being summarized.
log_string(f"test_pop_stack_counts {counts}")
# Vocabulary size is one past the largest token id seen in either split.
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1