From 4502a109727b0424ff6d4df90f17b361524f9e73 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 2 Jul 2023 18:34:10 +0200
Subject: [PATCH] Update.

---
 main.py  | 31 +++++++++++++++++++++++++++----
 stack.py |  9 +++++++--
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/main.py b/main.py
index 314a961..bb1e7b4 100755
--- a/main.py
+++ b/main.py
@@ -113,7 +113,9 @@ parser.add_argument("--stack_nb_steps", type=int, default=100)
 
 parser.add_argument("--stack_nb_stacks", type=int, default=1)
 
-parser.add_argument("--stack_nb_digits", type=int, default=1)
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
 
 ######################################################################
 
@@ -876,6 +878,7 @@ class TaskStack(Task):
         nb_steps,
         nb_stacks,
         nb_digits,
+        fraction_values_for_train=None,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
@@ -884,12 +887,31 @@
         self.nb_digits = nb_digits
         self.device = device
 
+        if fraction_values_for_train is None:
+            values_for_train = None
+            values_for_test = None
+        else:
+            all = torch.randperm(10**nb_digits)
+            nb_for_train = int(all.size(0) * fraction_values_for_train)
+            values_for_train = all[:nb_for_train]
+            values_for_test = all[nb_for_train:]
+
         self.train_input, self.train_stack_counts = stack.generate_sequences(
-            nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device
+            nb_train_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_train,
+            self.device,
         )
 
         self.test_input, self.test_stack_counts = stack.generate_sequences(
-            nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device
+            nb_test_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_test,
+            self.device,
         )
 
         mask = self.test_input.clone()
@@ -946,7 +968,7 @@
         )
 
         #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-        input = self.test_input[:10, :20]
+        input = self.test_input[:10, :50]
         result = input.clone()
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
@@ -1038,6 +1060,7 @@ elif args.task == "stack":
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
         nb_digits=args.stack_nb_digits,
+        fraction_values_for_train=args.stack_fraction_values_for_train,
         device=device,
     )
 
diff --git a/stack.py b/stack.py
index 675182e..219a1ad 100755
--- a/stack.py
+++ b/stack.py
@@ -13,7 +13,9 @@ import torch, torchvision
 
 # CODE_VAL=val + 2 * nb_stacks
 
-def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
+def generate_sequences(
+    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
@@ -26,7 +28,10 @@ def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
         op = op * (stack_counts[k, st] > 0)
-        val_push = torch.randint(10**nb_digits, (nb,))
+        if values is None:
+            val_push = torch.randint(10**nb_digits, (nb,))
+        else:
+            val_push = values[torch.randint(values.size(0), (nb,))]
         val_pop = stack[
             k,
             st,
-- 
2.39.5