From 5b7022591f48382ec84b1dda17297b1ed15166d7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 14 Jul 2024 10:49:09 +0200 Subject: [PATCH] Update. --- grids.py | 83 ++++++++++++++----------------------------------- quiz_machine.py | 3 -- 2 files changed, 24 insertions(+), 62 deletions(-) diff --git a/grids.py b/grids.py index aa21543..e651940 100755 --- a/grids.py +++ b/grids.py @@ -143,7 +143,7 @@ class Grids(problem.Problem): self.task_scale, self.task_symbols, self.task_isometry, - # self.task_path, + self.task_islands, ] if tasks is None: @@ -628,8 +628,8 @@ class Grids(problem.Problem): 1000, self.height, self.width, - nb_seeds=self.height * self.width // 9, - nb_iterations=self.height * self.width // 20, + nb_seeds=self.height * self.width // 8, + nb_iterations=self.height * self.width // 10, ) ) @@ -1051,65 +1051,30 @@ class Grids(problem.Problem): def task_islands(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - while True: - k = torch.randperm(self.height * self.width) - Z = torch.zeros(self.height + 2, self.width + 2) - - i0, j0 = ( - torch.randint(self.height, (1,)).item() + 1, - torch.randint(self.width, (1,)).item() + 1, + if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: + self.cache_islands = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 20, + nb_iterations=self.height * self.width // 2, + ) ) - Z[i0 - 1 : i0 + 2, j0 - 1 : j0 + 2] = 1 - - nb = 9 - - for q in k: - i, j = q % self.height + 1, q // self.height + 1 - - if Z[i, j] == 0: - r, s, t, u, v, w, x, y = ( - Z[i - 1, j], - Z[i - 1, j + 1], - Z[i, j + 1], - Z[i + 1, j + 1], - Z[i + 1, j], - Z[i + 1, j - 1], - Z[i, j - 1], - Z[i - 1, j - 1], - ) - - if ( - (nb < 16 or r + s + t + u + v + w + x + y > 0) - and (s == 0 or r + t > 0) - and (u == 0 or t + v > 0) - and (w == 0 or x + v > 0) - and (y == 0 or x + r > 0) - ): - # if r+s+t+u+v+w+x+y==0: - Z[i, j] = 1 - nb += 1 - - if nb == self.height * self.width // 2: - break - - if nb == self.height * self.width // 2: - break - - M = Z.clone() - Z[i0, j0] = 2 - X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1] + A = self.cache_islands.pop() - for _ in range(self.height + self.width): - Z[1:-1, 1:-1] = Z[1:-1, 1:-1].maximum( - torch.maximum( - torch.maximum(Z[0:-2, 1:-1], Z[2:, 1:-1]), - torch.maximum(Z[1:-1, 0:-2], Z[1:-1, 2:]), - ) + while True: + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), ) - Z *= M + if A[i, j] > 0: + break - f_X[...] = (Z[1:-1, 1:-1] == 1) * c[0] + (Z[1:-1, 1:-1] == 2) * c[1] + X[...] = (A > 0) * c[0] + X[i, j] = c[1] + f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0] ###################################################################### @@ -1201,7 +1166,7 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_count]: + for t in [grids.task_islands]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) grids.save_quiz_illustrations( @@ -1213,7 +1178,7 @@ if __name__ == "__main__": nb = 1000 # for t in grids.all_tasks: - for t in [grids.task_count]: + for t in [grids.task_islands]: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/quiz_machine.py b/quiz_machine.py index ef766c4..c49ecf2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -219,9 +219,6 @@ class QuizMachine: def generate_token_sequences(self, nb): prompts, answers = self.problem.generate_prompts_and_answers(nb) - print(f"DEBUG {prompts.size()=} {answers.size()=}") - sys.stdout.flush() - if self.prompt_len is None: self.prompt_len = prompts.size(1) -- 2.20.1