X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=38dccb9f8eea57437ce7574d4f87208ab0077b38;hb=27dea0ab0448511236cb344c17162e84359a14ee;hp=314a96134e150cad93ec81d492bc78d7396a75c4;hpb=87da428a5ab9ac3cd49ab22bd27e572d0b16f29c;p=picoclvr.git diff --git a/main.py b/main.py index 314a961..38dccb9 100755 --- a/main.py +++ b/main.py @@ -113,7 +113,9 @@ parser.add_argument("--stack_nb_steps", type=int, default=100) parser.add_argument("--stack_nb_stacks", type=int, default=1) -parser.add_argument("--stack_nb_digits", type=int, default=1) +parser.add_argument("--stack_nb_digits", type=int, default=3) + +parser.add_argument("--stack_fraction_values_for_train", type=float, default=None) ###################################################################### @@ -876,6 +878,7 @@ class TaskStack(Task): nb_steps, nb_stacks, nb_digits, + fraction_values_for_train=None, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -884,12 +887,31 @@ class TaskStack(Task): self.nb_digits = nb_digits self.device = device + if fraction_values_for_train is None: + values_for_train = None + values_for_test = None + else: + all = torch.randperm(10**nb_digits) + nb_for_train = int(all.size(0) * fraction_values_for_train) + values_for_train = all[:nb_for_train] + values_for_test = all[nb_for_train:] + self.train_input, self.train_stack_counts = stack.generate_sequences( - nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device + nb_train_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_train, + self.device, ) self.test_input, self.test_stack_counts = stack.generate_sequences( - nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device + nb_test_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_test, + self.device, ) mask = self.test_input.clone() @@ -946,7 +968,9 @@ class TaskStack(Task): ) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - input = self.test_input[:10, :20] + l = 50 + l = l - l % (1 + self.nb_digits) + input = self.test_input[:10, :l] result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() @@ -1038,6 +1062,7 @@ elif args.task == "stack": nb_steps=args.stack_nb_steps, nb_stacks=args.stack_nb_stacks, nb_digits=args.stack_nb_digits, + fraction_values_for_train=args.stack_fraction_values_for_train, device=device, )