From e6037cb4b23c6486e758be4f77adf95e0827e3e8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 27 Jul 2024 09:42:55 +0200 Subject: [PATCH] Update. --- grids.py | 36 +++++++++++++++++++++++++++++++++++- main.py | 2 +- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/grids.py b/grids.py index ae3a71e..d41ec49 100755 --- a/grids.py +++ b/grids.py @@ -1520,6 +1520,40 @@ class Grids(problem.Problem): if q >= 2: break + def collide(self, s, r, rs): + i, j = r + for i2, j2 in rs: + if abs(i - i2) < s and abs(j - j2) < s: + return True + return False + + def task_science_tag(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + rs = [] + while len(rs) < 4: + i, j = ( + torch.randint(self.height - 3, (1,)).item(), + torch.randint(self.width - 3, (1,)).item(), + ) + if not self.collide(s=3, r=(i, j), rs=rs): + rs.append((i, j)) + + for k in range(len(rs)): + i, j = rs[k] + q = min(k, 2) + X[i, j : j + 3] = c[q] + X[i + 2, j : j + 3] = c[q] + X[i : i + 3, j] = c[q] + X[i : i + 3, j + 2] = c[q] + + f_X[i, j : j + 3] = c[q] + f_X[i + 2, j : j + 3] = c[q] + f_X[i : i + 3, j] = c[q] + f_X[i : i + 3, j + 2] = c[q] + if q == 2: + f_X[i + 1, j + 1] = c[-1] + # end_tasks ###################################################################### @@ -1618,7 +1652,7 @@ if __name__ == "__main__": # for t in grids.all_tasks: - for t in [grids.task_science_dot]: + for t in [grids.task_science_tag]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( diff --git a/main.py b/main.py index 1b68eca..bf617b5 100755 --- a/main.py +++ b/main.py @@ -552,7 +552,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 rejected = [] - to_keep == quiz_machine.problem.trivial(c_quizzes) == False + to_keep = quiz_machine.problem.trivial(c_quizzes) == False if not to_keep.all(): rejected.append(c_quizzes[to_keep == False]) -- 2.39.5