From ae80c484bdecdd43ea7e335365040edb73c22edb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 17 Jul 2024 15:43:50 +0200 Subject: [PATCH] Update. --- grids.py | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/grids.py b/grids.py index e4d831c..7050b77 100755 --- a/grids.py +++ b/grids.py @@ -1131,6 +1131,45 @@ class Grids(problem.Problem): f_X[i, j + 5] = c[M2[i, j]] f_X[i + 5, j + 5] = c[P[i, j]] + def task_compute(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + v = torch.randint((self.width - 1) // 2, (N,)) + 1 + chain = torch.randperm(N) + eq = [] + for i in range(chain.size(0) - 1): + i1, i2 = chain[i], chain[i + 1] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + eq.append((c[i1], w1, c[i2], w2)) + + ii = torch.randperm(len(eq)) + + for k, x in enumerate(eq): + i = ii[k] + c1, w1, c2, w2 = x + X[i, 0:w1] = c1 + X[i, w1 : w1 + w2] = c2 + f_X[i, 0:w1] = c1 + f_X[i, w1 : w1 + w2] = c2 + + i1, i2 = torch.randperm(N)[:2] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + c1, c2 = c[i1], c[i2] + i = self.height - 1 + X[i, 0:w1] = c1 + X[i, w1 : w1 + 1] = c2 + f_X[i, 0:w1] = c1 + f_X[i, w1 : w1 + w2] = c2 + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): @@ -1221,7 +1260,7 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_matrices]: + for t in [grids.task_compute]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) grids.save_quiz_illustrations( -- 2.39.5