X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=4770a125172cf77a41aada959ef5f18b44517be4;hb=757876d57637e0da35f3680ec6ac9573b91f902a;hp=314a96134e150cad93ec81d492bc78d7396a75c4;hpb=87da428a5ab9ac3cd49ab22bd27e572d0b16f29c;p=picoclvr.git diff --git a/main.py b/main.py index 314a961..4770a12 100755 --- a/main.py +++ b/main.py @@ -37,7 +37,7 @@ parser.add_argument( parser.add_argument("--log_filename", type=str, default="train.log", help=" ") -parser.add_argument("--result_dir", type=str, default="results_default") +parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) @@ -113,7 +113,9 @@ parser.add_argument("--stack_nb_steps", type=int, default=100) parser.add_argument("--stack_nb_stacks", type=int, default=1) -parser.add_argument("--stack_nb_digits", type=int, default=1) +parser.add_argument("--stack_nb_digits", type=int, default=3) + +parser.add_argument("--stack_fraction_values_for_train", type=float, default=None) ###################################################################### @@ -121,6 +123,8 @@ args = parser.parse_args() assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +if args.result_dir is None: args.result_dir=f"results_{args.task}" + try: os.mkdir(args.result_dir) except FileExistsError: @@ -876,6 +880,7 @@ class TaskStack(Task): nb_steps, nb_stacks, nb_digits, + fraction_values_for_train=None, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -884,12 +889,31 @@ class TaskStack(Task): self.nb_digits = nb_digits self.device = device + if fraction_values_for_train is None: + values_for_train = None + values_for_test = None + else: + all = torch.randperm(10**nb_digits) + nb_for_train = int(all.size(0) * fraction_values_for_train) + values_for_train = all[:nb_for_train] + values_for_test = all[nb_for_train:] + self.train_input, self.train_stack_counts = stack.generate_sequences( - nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device + nb_train_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_train, + self.device, ) self.test_input, self.test_stack_counts = stack.generate_sequences( - nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device + nb_test_samples, + nb_steps, + nb_stacks, + nb_digits, + values_for_test, + self.device, ) mask = self.test_input.clone() @@ -946,7 +970,9 @@ class TaskStack(Task): ) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - input = self.test_input[:10, :20] + l = 50 + l = l - l % (1 + self.nb_digits) + input = self.test_input[:10, :l] result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() @@ -1038,6 +1064,7 @@ elif args.task == "stack": nb_steps=args.stack_nb_steps, nb_stacks=args.stack_nb_stacks, nb_digits=args.stack_nb_digits, + fraction_values_for_train=args.stack_fraction_values_for_train, device=device, )