From b718ef527d4bfb014a9ad564bb5199c7d0780aa9 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 24 Jul 2023 16:15:32 -1000
Subject: [PATCH] Update.

---
 main.py     | 12 ++++++++++--
 problems.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index af94979..68b946a 100755
--- a/main.py
+++ b/main.py
@@ -42,6 +42,10 @@ parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+
+########################################
+
 parser.add_argument("--nb_epochs", type=int, default=None)
 
 parser.add_argument("--batch_size", type=int, default=None)
@@ -56,6 +60,8 @@ parser.add_argument("--learning_rate", type=float, default=1e-4)
 
 parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
 
+########################################
+
 parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
@@ -70,6 +76,8 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
 
+########################################
+
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
@@ -570,8 +578,8 @@ log_string(
 )
 
 assert (
-    nb_in_train <= nb_test // 100
-), "More than 1% of test samples are in the train set"
+    nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
 
 ##############################
 
diff --git a/problems.py b/problems.py
index dca201f..7b1d698 100755
--- a/problems.py
+++ b/problems.py
@@ -21,6 +21,38 @@ class Problem:
 ####################
 
 
+class ProblemTwoTargets(Problem):
+    def __init__(self, len_total=10, len_target=2):
+        assert len_total >= 3 * (2 + len_target) - 1
+        self.len_total = len_total
+        self.len_target = len_target
+
+    def generate_sequences(self, nb):
+        k = torch.arange(self.len_total)[None, :]
+        l = torch.randint(self.len_total, (2, nb))[:, :, None] + 1
+        i = torch.randint(10, (2, nb))[:, :, None]
+        a = l[0]
+        b = l[0] + 1 + l[1]
+        c = l[0] + 1 + l[1] + 1 + l[0]
+        sequences = (
+            (k < a) * i[0]
+            + (k == a) * 10
+            + (k > a) * (k < b) * i[1]
+            + (k == b) * 11
+            + (k > b) * (k < c) * i[1]
+            + (k >= c) * 12
+        )
+        ar_mask = (sequences == 11).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789|>_"[x.item()] for x in seq)
+
+
+####################
+
+
 class ProblemLenId(Problem):
     def __init__(self, len_max=10):
         self.len_max = len_max
-- 
2.39.5
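
Reader's note (not part of the commit): the new ProblemTwoTargets task emits
sequences over a 13-token alphabet whose decoding is given by seq2str above:
indices 0-9 are digits, 10 is '|', 11 is '>', and 12 is the padding '_'. The
sketch below is a minimal, self-contained illustration of that layout and of
how ar_mask is derived. The example tensor is hand-built (hypothetical values,
not drawn from generate_sequences); the alphabet string and the ar_mask lines
are copied verbatim from the hunk above.

import torch

# Hand-built example with l[0] = 2, l[1] = 3, i[0] = 3, i[1] = 7, so that
# a = 2 ('|'), b = 6 ('>'), c = 9, and padding '_' fills the remainder.
seq = torch.tensor([[3, 3, 10, 7, 7, 7, 11, 7, 7, 12, 12, 12]])

# Same computation as in generate_sequences: mark every position strictly
# after the first '>' token (index 11) as a position to predict.
ar_mask = (seq == 11).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)

print("".join("0123456789|>_"[x.item()] for x in seq[0]))
# -> 33|777>77___
print(ar_mask)
# -> tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])

In other words, everything up to and including the '>' marker is given as the
prompt, and only the tokens after it are treated as autoregressive targets.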