Update.
[culture.git] / reasoning.py
index 9e26d64..c545e97 100755 (executable)
@@ -465,19 +465,95 @@ class Reasoning(problem.Problem):
                 for j in range(nb[n]):
                     f_X[n, j] = c[n]
 
                 for j in range(nb[n]):
                     f_X[n, j] = c[n]
 
-    def task_count_(self, A, f_A, B, f_B):
-        N = torch.randint(3, (1,)) + 1
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+    def task_trajectory(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
         for X, f_X in [(A, f_A), (B, f_B)]:
-            nb = torch.randint(self.width, (3,)) + 1
-            k = torch.randperm(self.height * self.width)[: nb.sum()]
-            p = 0
-            for n in range(N):
-                for m in range(nb[n]):
-                    i, j = k[p] % self.height, k[p] // self.height
-                    X[i, j] = c[n]
-                    f_X[n, m] = c[n]
-                    p += 1
+            while True:
+                di, dj = torch.randint(7, (2,)) - 3
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                if (
+                    abs(di) + abs(dj) > 0
+                    and i + 2 * di >= 0
+                    and i + 2 * di < self.height
+                    and j + 2 * dj >= 0
+                    and j + 2 * dj < self.width
+                ):
+                    break
+
+            k = 0
+            while (
+                i + k * di >= 0
+                and i + k * di < self.height
+                and j + k * dj >= 0
+                and j + k * dj < self.width
+            ):
+                if k < 2:
+                    X[i + k * di, j + k * dj] = c[k]
+                f_X[i + k * di, j + k * dj] = c[min(k, 1)]
+                k += 1
+
+    def task_bounce(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+
+            def free(i, j):
+                return (
+                    i >= 0
+                    and i < self.height
+                    and j >= 0
+                    and j < self.width
+                    and f_X[i, j] == 0
+                )
+
+            while True:
+                f_X[...] = 0
+                X[...] = 0
+
+                for _ in range((self.height * self.width) // 10):
+                    i, j = torch.randint(self.height, (1,)), torch.randint(
+                        self.width, (1,)
+                    )
+                    X[i, j] = c[0]
+                    f_X[i, j] = c[0]
+
+                while True:
+                    di, dj = torch.randint(7, (2,)) - 3
+                    if abs(di) + abs(dj) == 1:
+                        break
+
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+
+                X[i, j] = c[1]
+                f_X[i, j] = c[1]
+                l = 0
+
+                while True:
+                    l += 1
+                    if free(i + di, j + dj):
+                        pass
+                    elif free(i - dj, j + di):
+                        di, dj = -dj, di
+                        if free(i + dj, j - di):
+                            if torch.rand(1) < 0.5:
+                                di, dj = -di, -dj
+                    elif free(i + dj, j - di):
+                        di, dj = dj, -di
+                    else:
+                        break
+
+                    i, j = i + di, j + dj
+                    f_X[i, j] = c[2]
+                    if l <= 1:
+                        X[i, j] = c[2]
+
+                    if l >= self.width:
+                        break
+
+                f_X[i, j] = c[1]
+                X[i, j] = c[1]
+
+                if l > 3:
+                    break
 
     ######################################################################
 
 
     ######################################################################
 
@@ -490,6 +566,8 @@ class Reasoning(problem.Problem):
             self.task_frame,
             self.task_detect,
             self.task_count,
             self.task_frame,
             self.task_detect,
             self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)