From cc241681730e50ad149a68c612e3a06f2d4a71be Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 5 Jul 2024 09:10:31 +0300 Subject: [PATCH] Update. --- reasoning.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/reasoning.py b/reasoning.py index c545e97..cd726cb 100755 --- a/reasoning.py +++ b/reasoning.py @@ -555,6 +555,30 @@ class Reasoning(problem.Problem): if l > 3: break + def task_scale(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i, j = torch.randint(self.height // 2, (1,)), torch.randint( + self.width // 2, (1,) + ) + + for X, f_X in [(A, f_A), (B, f_B)]: + for _ in range(3): + while True: + i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: + break + X[i + i1 : i + i2, j + j1 : j + j2] = c[0] + f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] + + X[i, j] = c[1] + f_X[0:2, 0:2] = c[1] + ###################################################################### def generate_prompts_and_answers(self, nb, device="cpu"): @@ -568,6 +592,7 @@ class Reasoning(problem.Problem): self.task_count, self.task_trajectory, self.task_bounce, + self.task_scale, ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) -- 2.20.1