X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=14b1bc346064733faa1022e80c5278005a1bc359;hb=c5daf2eeedb26a25789de370171d592c621a2fac;hp=0323d0218ec587066817d5359044176cab99692b;hpb=76671c582f029aa67fce2626764b02e8d9e2dbeb;p=picoclvr.git diff --git a/main.py b/main.py index 0323d02..14b1bc3 100755 --- a/main.py +++ b/main.py @@ -109,7 +109,7 @@ parser.add_argument("--snake_length", type=int, default=200) ############################## # Snake options -parser.add_argument("--stack_nb_steps", type=int, default=25) +parser.add_argument("--stack_nb_steps", type=int, default=100) parser.add_argument("--stack_nb_stacks", type=int, default=1) @@ -166,9 +166,9 @@ default_args = { "nb_test_samples": 10000, }, "stack": { - "nb_epochs": 25, + "nb_epochs": 5, "batch_size": 25, - "nb_train_samples": 10000, + "nb_train_samples": 100000, "nb_test_samples": 1000, }, } @@ -892,6 +892,13 @@ class TaskStack(Task): nb_test_samples, nb_steps, nb_stacks, nb_values, self.device ) + mask = self.test_input.clone() + stack.remove_poped_values(mask,self.nb_stacks) + mask=(mask!=self.test_input) + counts = self.test_stack_counts.flatten()[mask.flatten()] + counts=F.one_hot(counts).sum(0) + log_string(f"stack_count {counts}") + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None):