Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 08:35:29 +0000 (10:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 08:35:29 +0000 (10:35 +0200)
grids.py

index 8d144cf..aa21543 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -67,26 +67,31 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
         M = F.conv2d(Z[:, None, :, :], w, padding=1)
         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
         M = ((M[:, 0] == 0) & (Z == 0)).long()
+        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
         M = M * torch.rand(M.size())
         M = M.flatten(1)
         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
-        U += M
+        U += M * Q
 
     for _ in range(nb_iterations):
         M = F.conv2d(Z[:, None, :, :], w, padding=1)
         M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
         M = ((M[:, 1] >= 0) & (Z == 0)).long()
+        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
         M = M * torch.rand(M.size())
         M = M.flatten(1)
         M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
         U = Z.flatten(1)
-        U += M
+        U += M * Q
 
     M = Z.clone()
     Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
 
-    for _ in range(100):
+    while True:
+        W = Z.clone()
         Z = F.max_pool2d(Z, 3, 1, 1) * M
+        if Z.equal(W):
+            break
 
     Z = Z.long()
     U = Z.flatten(1)
@@ -609,61 +614,50 @@ class Grids(problem.Problem):
         return no, nq, nq_diag
 
     def task_count(self, A, f_A, B, f_B):
-        N = torch.randint(4, (1,)).item() + 2
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
-
-        for X, f_X in [(A, f_A), (B, f_B)]:
-            l_q = torch.randperm(self.height * self.width)[
-                : self.height * self.width // 20
-            ]
-            l_d = torch.randint(N, l_q.size())
-            nb = torch.zeros(N, dtype=torch.int64)
-
-            for q, e in zip(l_q, l_d):
-                d = c[e]
-                i, j = q % self.height, q // self.height
-                if (
-                    nb[e] < self.width
-                    and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
-                ).all():
-                    X[i, j] = d
-                    nb[e] += 1
-
-            l_q = torch.randperm((self.height - 2) * (self.width - 2))[
-                : self.height * self.width // 2
-            ]
-            l_d = torch.randint(N, l_q.size())
-            for q, e in zip(l_q, l_d):
-                d = c[e]
-                i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
-                a1, a2, a3 = X[i - 1, j - 1 : j + 2]
-                a8, a4 = X[i, j - 1], X[i, j + 1]
-                a7, a6, a5 = X[i + 1, j - 1 : j + 2]
-                if (
-                    X[i, j] == 0
-                    and nb[e] < self.width
-                    and (a2 == 0 or a2 == d)
-                    and (a4 == 0 or a4 == d)
-                    and (a6 == 0 or a6 == d)
-                    and (a8 == 0 or a8 == d)
-                    and (a1 == 0 or a2 == d or a8 == d)
-                    and (a3 == 0 or a4 == d or a2 == d)
-                    and (a5 == 0 or a6 == d or a4 == d)
-                    and (a7 == 0 or a8 == d or a6 == d)
-                ):
-                    o = (
-                        (a2 != 0).long()
-                        + (a4 != 0).long()
-                        + (a6 != 0).long()
-                        + (a8 != 0).long()
+        while True:
+            error = False
+
+            N = torch.randint(5, (1,)).item() + 1
+            c = torch.zeros(N + 1)
+            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+            for X, f_X in [(A, f_A), (B, f_B)]:
+                if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
+                    self.cache_count = list(
+                        grow_islands(
+                            1000,
+                            self.height,
+                            self.width,
+                            nb_seeds=self.height * self.width // 9,
+                            nb_iterations=self.height * self.width // 20,
+                        )
                     )
-                    if o <= 1:
-                        X[i, j] = d
-                        nb[e] += 1 - o
 
-            for e in range(N):
-                for j in range(nb[e]):
-                    f_X[e, j] = c[e]
+                X[...] = self.cache_count.pop()
+
+                k = (X.max() + 1 + (c.size(0) - 1)).item()
+                V = torch.arange(k) // (c.size(0) - 1)
+                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+                    c.size(0) - 1
+                ) + 1
+                V[0] = 0
+                X[...] = c[V[X]]
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
+                    f_X[...] = 0
+                    for e in range(1, N + 1):
+                        for j in range((X == c[e]).sum() + 1):
+                            if j < self.width:
+                                f_X[e - 1, j] = c[e]
+                            else:
+                                error = True
+                                break
+                else:
+                    error = True
+                    break
+
+            if not error:
+                break
 
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
@@ -1214,12 +1208,12 @@ if __name__ == "__main__":
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
     # for t in grids.all_tasks:
-    for t in [grids.task_islands]:
+    for t in [grids.task_count]:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time