From 9f7966f5481f31ef9a531f4d2780f6b669499a71 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 20:03:23 +0300 Subject: [PATCH] Update. --- lang.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lang.py b/lang.py index abb7ca2..1472e04 100755 --- a/lang.py +++ b/lang.py @@ -283,6 +283,16 @@ class Lang(problem.Problem): if n == N - 1: f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + def task_detect(self, A, f_A, B, f_B): + N = 3 + c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(X, N) + for n in range(N): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1, j1] = c[-1] + ###################################################################### def generate_prompts_and_answers(self, nb): @@ -292,6 +302,7 @@ class Lang(problem.Problem): self.task_grow, self.task_color_grow, self.task_frame, + self.task_detect, ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) -- 2.39.5