From ef1517c154a3b9a151ddb2d375a168cf9cf05b85 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 08:54:27 +0200 Subject: [PATCH] Update. --- grids.py | 70 +++++++++++++++++++++++++++++-------------------- main.py | 22 +++++----------- quiz_machine.py | 37 ++++++++++++++++---------- 3 files changed, 72 insertions(+), 57 deletions(-) diff --git a/grids.py b/grids.py index 8d274ad..f12fcb9 100755 --- a/grids.py +++ b/grids.py @@ -167,6 +167,19 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct + def inject_noise(self, quizzes, noise, struct, mask): + assert self.check_structure(quizzes, struct=struct) + S = self.height * self.width + mask = torch.tensor(mask, device=quizzes.device) + mask = mask[None, :, None].expand(1, 4, S + 1) + mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() + mask = mask.reshape(1, -1).expand_as(quizzes) + random = torch.randint(self.nb_colors, mask.size()) + + quizzes = mask * random + (1 - mask) * quizzes + + return quizzes + # What a mess def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): if torch.is_tensor(quizzes): @@ -237,7 +250,8 @@ class Grids(problem.Problem): ): self.colors = torch.tensor([c for _, c in self.named_colors]) - self.token_A = len(self.colors) + self.nb_colors = len(self.colors) + self.token_A = self.nb_colors self.token_f_A = self.token_A + 1 self.token_B = self.token_f_A + 1 self.token_f_B = self.token_B + 1 @@ -294,7 +308,7 @@ class Grids(problem.Problem): ###################################################################### def grid2img(self, x, scale=15): - m = torch.logical_and(x >= 0, x < len(self.colors)).long() + m = torch.logical_and(x >= 0, x < self.nb_colors).long() y = self.colors[x * m].permute(0, 3, 1, 2) s = y.shape y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) @@ -562,7 +576,7 @@ class Grids(problem.Problem): # @torch.compile def task_replace_color(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): @@ -578,7 +592,7 @@ class Grids(problem.Problem): break nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: r = self.rec_coo(nb_rec, prevent_overlap=True) @@ -603,7 +617,7 @@ class Grids(problem.Problem): def task_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 direction = torch.randint(2, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: while True: @@ -629,7 +643,7 @@ class Grids(problem.Problem): def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1 direction = torch.randint(4, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: r = self.rec_coo(nb_rec, prevent_overlap=True) @@ -670,7 +684,7 @@ class Grids(problem.Problem): # @torch.compile def task_frame(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): @@ -687,7 +701,7 @@ class Grids(problem.Problem): # @torch.compile def task_detect(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): @@ -739,7 +753,7 @@ class Grids(problem.Problem): N = 3 c = torch.zeros(N + 2, dtype=torch.int64) - c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: if not hasattr(self, "cache_count") or len(self.cache_count) == 0: @@ -785,7 +799,7 @@ class Grids(problem.Problem): # @torch.compile def task_trajectory(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: di, dj = torch.randint(7, (2,)) - 3 @@ -816,7 +830,7 @@ class Grids(problem.Problem): # @torch.compile def task_bounce(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:3] + 1 + c = torch.randperm(self.nb_colors - 1)[:3] + 1 for X, f_X in [(A, f_A), (B, f_B)]: # @torch.compile def free(i, j): @@ -885,7 +899,7 @@ class Grids(problem.Problem): # @torch.compile def task_scale(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 i, j = ( torch.randint(self.height // 2, (1,)).item(), @@ -915,7 +929,7 @@ class Grids(problem.Problem): # @torch.compile def task_symbols(self, A, f_A, B, f_B): nb_rec = 4 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 delta = 3 for X, f_X in [(A, f_A), (B, f_B)]: while True: @@ -967,7 +981,7 @@ class Grids(problem.Problem): X[...] = 0 f_X[...] = 0 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for r in range(nb_rec): while True: @@ -1034,7 +1048,7 @@ class Grids(problem.Problem): # @torch.compile def REMOVED_task_distance(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:3] + 1 + c = torch.randperm(self.nb_colors - 1)[:3] + 1 dist0 = torch.empty(self.height + 2, self.width + 2) dist1 = torch.empty(self.height + 2, self.width + 2) for X, f_X in [(A, f_A), (B, f_B)]: @@ -1099,7 +1113,7 @@ class Grids(problem.Problem): def TOO_HARD_task_puzzle(self, A, f_A, B, f_B): S = 4 i0, j0 = (self.height - S) // 2, (self.width - S) // 2 - c = torch.randperm(len(self.colors) - 1)[:4] + 1 + c = torch.randperm(self.nb_colors - 1)[:4] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: f_X[...] = 0 @@ -1165,7 +1179,7 @@ class Grids(problem.Problem): X[ii + i, jj + j] = c[d] def TOO_MESSY_task_islands(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: self.cache_islands = list( @@ -1196,7 +1210,7 @@ class Grids(problem.Problem): # @torch.compile def TOO_HARD_task_stack(self, A, f_A, B, f_B): N = 5 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: i1, j1, i2, j2 = ( self.height // 2 - 1, @@ -1241,7 +1255,7 @@ class Grids(problem.Problem): def TOO_HARD_task_matrices(self, A, f_A, B, f_B): N = 6 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: M1 = torch.randint(2, (5, 5)) @@ -1257,7 +1271,7 @@ class Grids(problem.Problem): def TOO_HARD_task_compute(self, A, f_A, B, f_B): N = 6 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: v = torch.randint((self.width - 1) // 2, (N,)) + 1 chain = torch.randperm(N) @@ -1307,7 +1321,7 @@ class Grids(problem.Problem): return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0)) nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: r = self.rec_coo(nb_rec, prevent_overlap=True) @@ -1330,7 +1344,7 @@ class Grids(problem.Problem): def task_corners(self, A, f_A, B, f_B): polarity = torch.randint(2, (1,)).item() nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: r = self.rec_coo(nb_rec, prevent_overlap=True) @@ -1367,7 +1381,7 @@ class Grids(problem.Problem): # @torch.compile def task_path(self, A, f_A, B, f_B): nb_rec = 2 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: X[...] = 0 @@ -1403,7 +1417,7 @@ class Grids(problem.Problem): # @torch.compile def task_fill(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: accept_full = torch.rand(1) < 0.5 @@ -1435,7 +1449,7 @@ class Grids(problem.Problem): break def TOO_HARD_task_addition(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:4] + 1 + c = torch.randperm(self.nb_colors - 1)[:4] + 1 for X, f_X in [(A, f_A), (B, f_B)]: N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() @@ -1452,7 +1466,7 @@ class Grids(problem.Problem): def task_science_implicit(self, A, f_A, B, f_B): nb_rec = 5 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: @@ -1505,7 +1519,7 @@ class Grids(problem.Problem): def task_science_dot(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: X[...] = 0 @@ -1539,7 +1553,7 @@ class Grids(problem.Problem): return False def task_science_tag(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:4] + 1 + c = torch.randperm(self.nb_colors - 1)[:4] + 1 for X, f_X in [(A, f_A), (B, f_B)]: rs = [] while len(rs) < 4: diff --git a/main.py b/main.py index cce747a..72b2b26 100755 --- a/main.py +++ b/main.py @@ -43,8 +43,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) parser.add_argument("--log_command", type=str, default=None) -######################################## - +# ---------------------------------- parser.add_argument("--nb_epochs", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=None) @@ -63,8 +62,7 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) parser.add_argument("--learning_rate", type=float, default=5e-4) -######################################## - +# ---------------------------------- parser.add_argument("--model", type=str, default=None) parser.add_argument("--dim_model", type=int, default=None) @@ -79,8 +77,7 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) -######################################## - +# ---------------------------------- parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--problem", type=str, default="grids") @@ -89,6 +86,8 @@ parser.add_argument("--nb_threads", type=int, default=1) parser.add_argument("--gpus", type=str, default="all") +# ---------------------------------- + parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--max_fail_to_validate", type=int, default=2) @@ -103,7 +102,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5) parser.add_argument("--temperature_cold", type=float, default=1) -parser.add_argument("--nb_rounds", type=int, default=1) +parser.add_argument("--prompt_noise", type=float, default=0.0) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -343,6 +342,7 @@ quiz_machine = quiz_machine.QuizMachine( problem=problem, batch_size=args.inference_batch_size, result_dir=args.result_dir, + prompt_noise=args.prompt_noise, logger=log_string, device=main_device, ) @@ -578,14 +578,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[to_keep] - # We go through nb_rounds rounds and keep only quizzes on - # which - # - # (1) models respond always the same through rounds, and - # - # (2) at least one and up to max_fail_to_validate model(s) - # fail(s) - # This is nb_quizzes x nb_models seq_logproba = quiz_machine.models_logprobas( diff --git a/quiz_machine.py b/quiz_machine.py index 90879ce..1e973c5 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -67,6 +67,7 @@ class QuizMachine: problem, batch_size, result_dir, + prompt_noise, logger, device=torch.device("cpu"), ): @@ -78,6 +79,7 @@ class QuizMachine: self.logger = logger self.prompt_len = None self.answer_len = None + self.prompt_noise = prompt_noise self.understood_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), @@ -153,6 +155,7 @@ class QuizMachine: if len(c_quizzes) > 0: c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] c_quizzes = c_quizzes[i] @@ -171,13 +174,19 @@ class QuizMachine: quizzes = w_quizzes.clone() from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) - self.randomize_configuations_inplace( - quizzes, structs=[s for s, m in self.understood_structures] + i = torch.randperm(quizzes.size(0), device=quizzes.device) + quizzes, from_w = quizzes[i], from_w[i] + + if self.prompt_noise > 0.0: + quizzes = self.problem.inject_noise( + quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0) ) - i = torch.randperm(quizzes.size(0), device=quizzes.device) + self.randomize_configuations_inplace( + quizzes, structs=[s for s, m in self.understood_structures] + ) - return quizzes[i], from_w[i] + return quizzes, from_w ###################################################################### @@ -267,20 +276,24 @@ class QuizMachine: def renew_train_w_quizzes(self, model): if hasattr(model, "hard_w_quizzes"): - if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): + hard_w_quizzes = self.problem.reconfigure( + model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B") + ) + self.logger( + f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" + ) + if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): nb_to_generate = 0 - model.train_w_quizzes[...] = model.hard_w_quizzes[ + model.train_w_quizzes[...] = hard_w_quizzes[ torch.randperm(hard_w_quizzes.size(0))[ model.train_w_quizzes.size(0) ] ] else: - nb_to_generate = model.train_w_quizzes.size( - 0 - ) - model.hard_w_quizzes.size(0) + nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0) model.train_w_quizzes[...] = torch.cat( [ - model.hard_w_quizzes, + hard_w_quizzes, self.problem.generate_w_quizzes(nb_to_generate), ], dim=0, @@ -291,10 +304,6 @@ class QuizMachine: model.train_w_quizzes.size(0) ) - self.logger( - f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" - ) - self.randomize_configuations_inplace( model.train_w_quizzes, structs=[s for s, m in self.understood_structures] ) -- 2.39.5