Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 16:34:26 +0000 (19:34 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 16:34:26 +0000 (19:34 +0300)
lang.py

diff --git a/lang.py b/lang.py
index 4c3c64f..43550d7 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -172,74 +172,116 @@ class Lang(problem.Problem):
     def nb_token_values(self):
         return len(self.colors)
 
-    def rec_coo(self, x):
+    def rec_coo(self, x, n, min_height=3, min_width=3):
         while True:
-            i1, i2 = torch.randint(x.size(0), (2,))
-            if i1 < i2 - 1:
-                break
-        while True:
-            j1, j2 = torch.randint(x.size(1), (2,))
-            if j1 < j2 - 1:
+            collision = x.new_zeros(x.size())
+            result = []
+            for _ in range(n):
+                while True:
+                    i1, i2 = torch.randint(x.size(0), (2,))
+                    if i1 + min_height <= i2:
+                        break
+                while True:
+                    j1, j2 = torch.randint(x.size(1), (2,))
+                    if j1 + min_width <= j2:
+                        break
+                collision[i1:i2, j1:j2] += 1
+                if collision.max() > 1:
+                    break
+                result.append((i1, j1, i2, j2))
+            if collision.max() == 1:
                 break
-        return i1, j1, i2, j2
+        return result
 
     ######################################################################
 
     def task_replace_color(self, A, f_A, B, f_B):
-        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
-        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
-            for _ in range(torch.randint(n, (1,)) + 1):
-                i1, j1, i2, j2 = self.rec_coo(X)
-                X[i1:i2, j1:j2] = c1
-                f_X[i1:i2, j1:j2] = c2
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(X, N)
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
     def task_move(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
-            c = torch.randperm(len(self.colors) - 1)[:1] + 1
-            for _ in range(torch.randint(n, (1,)) + 1):
-                while True:
-                    i1, j1, i2, j2 = self.rec_coo(X)
-                    if (
-                        i1 + di >= 0
-                        and i2 + di < X.size(0)
-                        and j1 + dj >= 0
-                        and j2 + dj < X.size(1)
-                    ):
-                        break
-
-                X[i1:i2, j1:j2] = c
-                f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(X, N)
+                i1, j1, i2, j2 = r[N - 1]
+                if (
+                    i1 + di >= 0
+                    and i2 + di < X.size(0)
+                    and j1 + dj >= 0
+                    and j2 + dj < X.size(1)
+                ):
+                    break
+
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if n == N - 1:
+                    f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
 
     def task_grow(self, A, f_A, B, f_B):
+        di, dj = torch.randint(2, (2,)) * 2 - 1
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
         direction = torch.randint(2, (1,))
-        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
-            c = torch.randperm(len(self.colors) - 1)[:1] + 1
-            for _ in range(torch.randint(n, (1,)) + 1):
-                while True:
-                    i1, j1, i2, j2 = self.rec_coo(X)
-                    if i1 + 3 < i2 and j1 + 3 < j2:
-                        break
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(X, N)
+                i1, j1, i2, j2 = r[N - 1]
+                if i1 + 3 < i2 and j1 + 3 < j2:
+                    break
+
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                if n == N - 1:
+                    if direction == 0:
+                        X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                        f_X[i1:i2, j1:j2] = c[n]
+                    else:
+                        X[i1:i2, j1:j2] = c[n]
+                        f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                else:
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
 
-                if direction == 0:
-                    X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
-                    f_X[i1:i2, j1:j2] = c
+    def task_color_grow(self, A, f_A, B, f_B):
+        di, dj = torch.randint(2, (2,)) * 2 - 1
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
+        direction = torch.randint(2, (1,))
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(X, N)
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
+                f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
+                X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
+                if n == N - 1:
+                    f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1]
                 else:
-                    X[i1:i2, j1:j2] = c
-                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+                    f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
 
     def task_frame(self, A, f_A, B, f_B):
-        direction = torch.randint(2, (1,))
-        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
-            c = torch.randperm(len(self.colors) - 1)[:1] + 1
-            for _ in range(torch.randint(n, (1,)) + 1):
-                while True:
-                    i1, j1, i2, j2 = self.rec_coo(X)
-                    if i1 + 3 < i2 and j1 + 3 < j2:
-                        break
-                X[i1:i2, j1:j2] = c
-                f_X[i1:i2, j1:j2] = c
-                f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+        N = 3
+        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(X, N)
+            for n in range(N):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+                if n == N - 1:
+                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
     ######################################################################
 
@@ -248,6 +290,7 @@ class Lang(problem.Problem):
             self.task_replace_color,
             self.task_move,
             self.task_grow,
+            self.task_color_grow,
             self.task_frame,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)