Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 22:20:13 +0000 (01:20 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 22:20:13 +0000 (01:20 +0300)
grids.py

index ed72099..2d1293c 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -614,6 +614,60 @@ class Grids(problem.Problem):
 
             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
 
+    def task_ortho(self, A, f_A, B, f_B):
+        nb_rec = 3
+        di, dj = torch.randint(3, (2,)) - 1
+        o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
+        m = torch.eye(2)
+        for _ in range(torch.randint(4, (1,))):
+            m = m @ o
+        if torch.rand(1) < 0.5:
+            m[0, :] = -m[0, :]
+
+        ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+
+                for r in range(nb_rec):
+                    while True:
+                        i1, i2 = torch.randint(self.height - 2, (2,)) + 1
+                        j1, j2 = torch.randint(self.width - 2, (2,)) + 1
+                        if (
+                            i2 >= i1
+                            and j2 >= j1
+                            and max(i2 - i1, j2 - j1) >= 2
+                            and min(i2 - i1, j2 - j1) <= 3
+                        ):
+                            break
+                    X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+                    i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
+
+                    i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
+                    i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
+
+                    i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
+                    i1, i2 = i1.long() + di, i2.long() + di
+                    j1, j2 = j1.long() + dj, j2.long() + dj
+                    if i1 > i2:
+                        i1, i2 = i2, i1
+                    if j1 > j2:
+                        j1, j2 = j2, j1
+
+                    f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+                n = F.one_hot(X.flatten()).sum(dim=0)[1:]
+                if (
+                    n.sum() > self.height * self.width // 4
+                    and (n > 0).long().sum() == nb_rec
+                ):
+                    break
+
     def task_islands(self, A, f_A, B, f_B):
         pass
 
@@ -703,7 +757,7 @@ if __name__ == "__main__":
     grids = Grids()
 
     # for t in grids.all_tasks():
-    for t in [grids.task_islands]:
+    for t in [grids.task_ortho]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)