Add task_path and task_fill tasks with a compdist distance helper, fix the rectangle no-overlap test, reorder the task list, and update the demo defaults in __main__.
author François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 09:12:25 +0000 (11:12 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 09:12:25 +0000 (11:12 +0200)
grids.py

index ba0131d..5778f85 100755
--- a/grids.py
+++ b/grids.py
@@ -214,14 +214,17 @@ class Grids(problem.Problem):
             self.task_half_fill,
             self.task_frame,
             self.task_detect,
-            # self.task_count, # NOT REVERSIBLE
-            self.task_trajectory,
-            self.task_bounce,
             self.task_scale,
             self.task_symbols,
-            self.task_isometry,
             self.task_corners,
             self.task_contact,
+            self.task_path,
+            self.task_fill,
+            ############################################ hard ones
+            self.task_isometry,
+            self.task_trajectory,
+            self.task_bounce,
+            # self.task_count, # NOT REVERSIBLE
             # self.task_islands, # TOO MESSY
         ]
 
@@ -467,13 +470,13 @@ class Grids(problem.Problem):
                         j[:, 1, 0],
                         j[:, 1, 1],
                     )
-                    no_overlap = torch.logical_not(
+                    no_overlap = (
                         (A_i1 >= B_i2)
-                        & (A_i2 <= B_i1)
-                        & (A_j1 >= B_j1)
-                        & (A_j2 <= B_j1)
+                        | (A_i2 <= B_i1)
+                        | (A_j1 >= B_j2)
+                        | (A_j2 <= B_j1)
                     )
-                    i, j = i[no_overlap], j[no_overlap]
+                    i, j = (i[no_overlap], j[no_overlap])
                 elif nb_rec == 3:
                     A_i1, A_i2, A_j1, A_j2 = (
                         i[:, 0, 0],
@@ -1322,6 +1325,100 @@ class Grids(problem.Problem):
                     X[i2 - 1, j1] = c[n]
                 f_X[i1:i2, j1:j2] = c[n]
 
+    def compdist(self, X, i, j):
+        dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
+        d = dd[1:-1, 1:-1]
+        m = (X > 0).long()
+        d[i, j] = 0
+        e = d.clone()
+        while True:
+            e[...] = d
+            d[...] = (
+                d.min(dd[:-2, 1:-1] + 1)
+                .min(dd[2:, 1:-1] + 1)
+                .min(dd[1:-1, :-2] + 1)
+                .min(dd[1:-1, 2:] + 1)
+            )
+            d[...] = (1 - m) * d + m * self.height * self.width
+            if e.equal(d):
+                break
+
+        return d
+
+    # @torch.compile
+    def task_path(self, A, f_A, B, f_B):
+        nb_rec = 2
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                while True:
+                    i1, i2 = torch.randint(self.height, (2,))
+                    j1, j2 = torch.randint(self.width, (2,))
+                    if (
+                        abs(i1 - i2) + abs(j1 - j2) > 2
+                        and X[i1, j1] == 0
+                        and X[i2, j2] == 0
+                    ):
+                        break
+
+                d2 = self.compdist(X, i2, j2)
+                d = self.compdist(X, i1, j1)
+
+                if d2[i1, j1] < 2 * self.width:
+                    break
+
+            m = ((d + d2) == d[i2, j2]).long()
+            f_X[...] = m * c[-1] + (1 - m) * f_X
+
+            X[i1, j1] = c[-2]
+            X[i2, j2] = c[-2]
+            f_X[i1, j1] = c[-2]
+            f_X[i2, j2] = c[-2]
+
+    # @torch.compile
+    def task_fill(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            accept_full = torch.rand(1) < 0.5
+
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                while True:
+                    i, j = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i, j] == 0:
+                        break
+
+                d = self.compdist(X, i, j)
+                m = (d < self.height * self.width).long()
+                X[i, j] = c[-1]
+                f_X[...] = m * c[-1] + (1 - m) * f_X
+
+                if accept_full or (d * (X == 0)).max() == self.height * self.width:
+                    break
+
+    # end_tasks
+
     ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
@@ -1418,11 +1515,11 @@ if __name__ == "__main__":
     # exit(0)
 
     # if True:
-    nb, nrow = 8, 2
+    nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_count]:
+    for t in [grids.task_fill]:
         # for t in [grids.task_symbols]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
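
A minimal standalone sketch (not part of the commit) of the 4-neighbour min-relaxation that the new Grids.compdist performs; the function name compute_distance and the toy grid below are hypothetical, for illustration only.

import torch

def compute_distance(X, i, j):
    # Distance from cell (i, j) to every free cell of X, where cells with
    # X > 0 are obstacles pinned at the "unreachable" value height * width.
    height, width = X.shape
    big = height * width
    dd = X.new_full((height + 2, width + 2), big)   # padded grid
    d = dd[1:-1, 1:-1]                              # interior view into dd
    m = (X > 0).long()                              # obstacle mask
    d[i, j] = 0                                     # source cell
    while True:
        prev = d.clone()
        # Relax: each cell takes the min of itself and its four neighbours + 1.
        d[...] = (
            d.min(dd[:-2, 1:-1] + 1)
            .min(dd[2:, 1:-1] + 1)
            .min(dd[1:-1, :-2] + 1)
            .min(dd[1:-1, 2:] + 1)
        )
        d[...] = (1 - m) * d + m * big              # obstacles stay unreachable
        if prev.equal(d):                           # fixed point reached
            break
    return d

X = torch.zeros(5, 7, dtype=torch.int64)
X[1:4, 3] = 1                                       # a partial wall of obstacles
print(compute_distance(X, 2, 0))                    # distances bend around the wall

In the committed tasks, task_path keeps the cells where d + d2 equals the source-to-target distance (a shortest path between the two marked cells), and task_fill keeps the cells whose distance stays below height * width (the empty region connected to the seed cell).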