Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 5 Jul 2024 06:10:31 +0000 (09:10 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 5 Jul 2024 06:10:31 +0000 (09:10 +0300)
reasoning.py

index c545e97..cd726cb 100755 (executable)
@@ -555,6 +555,30 @@ class Reasoning(problem.Problem):
                 if l > 3:
                     break
 
+    def task_scale(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+
+        i, j = torch.randint(self.height // 2, (1,)), torch.randint(
+            self.width // 2, (1,)
+        )
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            for _ in range(3):
+                while True:
+                    i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
+                        break
+                X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
+                f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
+
+            X[i, j] = c[1]
+            f_X[0:2, 0:2] = c[1]
+
     ######################################################################
 
     def generate_prompts_and_answers(self, nb, device="cpu"):
@@ -568,6 +592,7 @@ class Reasoning(problem.Problem):
             self.task_count,
             self.task_trajectory,
             self.task_bounce,
+            self.task_scale,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)