X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=e00ee665aaaec075a177c5776fa2d0093a2b704f;hb=0b147af672d69d5fca328bc937467993c22fb20d;hp=c1f4dc7f0540e2dcbbdf8c71b9a3c1ca29db457b;hpb=f91736e6e56152746b3c44342748b70ad1c89888;p=picoclvr.git diff --git a/main.py b/main.py index c1f4dc7..e00ee66 100755 --- a/main.py +++ b/main.py @@ -922,7 +922,7 @@ class TaskStack(Task): i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks) counts = self.test_stack_counts.flatten()[i.flatten()] counts = F.one_hot(counts).sum(0) - log_string(f"pop_stack_counts {counts}") + log_string(f"test_pop_stack_counts {counts}") self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1