From e781d77071fa26f393f50451f91c70f4a0850ca5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 00:59:24 +0200 Subject: [PATCH] Update. --- main.py | 6 +++--- tasks.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 0d4930d..19f918c 100755 --- a/main.py +++ b/main.py @@ -86,7 +86,7 @@ parser.add_argument("--sandbox_level", type=int, default=0) parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) -parser.add_argument("--sandbox_levels_len_source", type=int, default=5) +parser.add_argument("--sandbox_levels_len_source", type=int, default=6) parser.add_argument("--sandbox_levels_len_result", type=int, default=8) @@ -163,9 +163,9 @@ if args.result_dir is None: default_args = { "sandbox": { - "nb_epochs": 10, + "nb_epochs": 50, "batch_size": 25, - "nb_train_samples": 25000, + "nb_train_samples": 100000, "nb_test_samples": 10000, }, "picoclvr": { diff --git a/tasks.py b/tasks.py index e7c2f75..c5418b4 100755 --- a/tasks.py +++ b/tasks.py @@ -104,7 +104,8 @@ class ProblemLevel1(Problem): // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) ) % 10 marker1 = torch.full((nb, 1), 10) - source = torch.randint(10, (nb, self.len_source)) + # source = torch.randint(10, (nb, self.len_source)) + source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] marker2 = torch.full((nb, 1), 11) result = operators.bmm(source[:, :, None]).squeeze(-1) print(f"{nb_operators.dtype=} {marker1.dtype=}") @@ -128,7 +129,8 @@ class ProblemLevel2(Problem): torch.rand(nb, self.len_result, self.len_source).argmax(-1), num_classes=self.len_source, ) - source1 = torch.randint(10, (nb, self.len_source)) + source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + # source1 = torch.randint(10, (nb, self.len_source)) marker1 = torch.full((nb, 1), 10) result1 = operators.bmm(source1[:, :, None]).squeeze(-1) marker2 = torch.full((nb, 1), 11) -- 2.39.5