Update.
[culture.git] / grids.py
index 6e9e6c7..20a964b 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -769,9 +769,97 @@ class Grids(problem.Problem):
                 ):
                     break
 
                 ):
                     break
 
+    def compute_distance(self, walls, goal_i, goal_j, start_i, start_j):
+        max_length = walls.numel()
+        dist = torch.full_like(walls, max_length)
+
+        dist[goal_i, goal_j] = 0
+        pred_dist = torch.empty_like(dist)
+
+        while True:
+            pred_dist.copy_(dist)
+            d = (
+                torch.cat(
+                    (
+                        dist[None, 1:-1, 0:-2],
+                        dist[None, 2:, 1:-1],
+                        dist[None, 1:-1, 2:],
+                        dist[None, 0:-2, 1:-1],
+                    ),
+                    0,
+                ).min(dim=0)[0]
+                + 1
+            )
+
+            dist[1:-1, 1:-1].minimum_(d)  # = torch.min(dist[1:-1, 1:-1], d)
+            dist = walls * max_length + (1 - walls) * dist
+
+            if dist[start_i, start_j] < max_length or dist.equal(pred_dist):
+                return dist * (1 - walls)
+
     # @torch.compile
     # @torch.compile
-    def task_islands(self, A, f_A, B, f_B):
-        pass
+    def task_path(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        dist = torch.empty(self.height + 2, self.width + 2)
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb_rec = torch.randint(3, (1,)) + 1
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                X[...] = 0
+                f_X[...] = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[0]
+                    f_X[i1:i2, j1:j2] = c[0]
+                while True:
+                    i0, j0 = torch.randint(self.height, (1,)), torch.randint(
+                        self.width, (1,)
+                    )
+                    if X[i0, j0] == 0:
+                        break
+                while True:
+                    i1, j1 = torch.randint(self.height, (1,)), torch.randint(
+                        self.width, (1,)
+                    )
+                    if X[i1, j1] == 0:
+                        break
+                dist[...] = 1
+                dist[1:-1, 1:-1] = (X != 0).long()
+                dist[...] = self.compute_distance(dist, i1 + 1, j1 + 1, i0 + 1, j0 + 1)
+                if dist[i0 + 1, j0 + 1] >= 1 and dist[i0 + 1, j0 + 1] < self.height * 4:
+                    break
+
+            dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
+            dist[0, :] = self.height * self.width
+            dist[-1, :] = self.height * self.width
+            dist[:, 0] = self.height * self.width
+            dist[:, -1] = self.height * self.width
+            # dist += torch.rand(dist.size())
+
+            i, j = i0 + 1, j0 + 1
+            while i != i1 + 1 or j != j1 + 1:
+                f_X[i - 1, j - 1] = c[2]
+                r, s, t, u = (
+                    dist[i - 1, j],
+                    dist[i, j - 1],
+                    dist[i + 1, j],
+                    dist[i, j + 1],
+                )
+                m = min(r, s, t, u)
+                if r == m:
+                    i = i - 1
+                elif t == m:
+                    i = i + 1
+                elif s == m:
+                    j = j - 1
+                else:
+                    j = j + 1
+
+            X[i0, j0] = c[2]
+            # f_X[i0, j0] = c[1]
+
+            X[i1, j1] = c[1]
+            f_X[i1, j1] = c[1]
 
     # for X, f_X in [(A, f_A), (B, f_B)]:
     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
 
     # for X, f_X in [(A, f_A), (B, f_B)]:
     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
@@ -797,7 +885,7 @@ class Grids(problem.Problem):
             self.task_scale,
             self.task_symbols,
             self.task_ortho,
             self.task_scale,
             self.task_symbols,
             self.task_ortho,
-            # self.task_islands,
+            #            self.task_path,
         ]
 
     def trivial_prompts_and_answers(self, prompts, answers):
         ]
 
     def trivial_prompts_and_answers(self, prompts, answers):
@@ -854,6 +942,15 @@ class Grids(problem.Problem):
             nrow,
         )
 
             nrow,
         )
 
+    def save_some_examples(self, result_dir):
+        nb, nrow = 72, 4
+        for t in self.all_tasks():
+            print(t.__name__)
+            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
+            self.save_quizzes(
+                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+            )
+
 
 ######################################################################
 
 
 ######################################################################
 
@@ -875,15 +972,17 @@ if __name__ == "__main__":
     # exit(0)
 
     # if True:
     # exit(0)
 
     # if True:
-    # nb,nrow = 72,4
-    nb, nrow = 8, 2
+    nb, nrow = 72, 4
+    nb, nrow = 8, 2
 
 
-    for t in grids.all_tasks():
-        # for t in [grids.task_replace_color]:
+    for t in grids.all_tasks():
+    for t in [grids.task_path]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
 
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
 
+    # exit(0)
+
     nb = 1000
 
     for t in grids.all_tasks():
     nb = 1000
 
     for t in grids.all_tasks():