Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 17:03:23 +0000 (20:03 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 17:03:23 +0000 (20:03 +0300)
lang.py

diff --git a/lang.py b/lang.py
index abb7ca2..1472e04 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -283,6 +283,16 @@ class Lang(problem.Problem):
                 if n == N - 1:
                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
+    def task_detect(self, A, f_A, B, f_B):
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(X, N)
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1, j1] = c[-1]
+
     ######################################################################
 
     def generate_prompts_and_answers(self, nb):
@@ -292,6 +302,7 @@ class Lang(problem.Problem):
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
+            self.task_detect,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)