Update.
[culture.git] / grids.py
index 9462f87..659bd6c 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -23,9 +23,9 @@ class Grids(problem.Problem):
         ("red", [255, 0, 0]),
         ("green", [0, 192, 0]),
         ("blue", [0, 0, 255]),
-        ("orange", [255, 192, 0]),
+        ("yellow", [255, 224, 0]),
         ("cyan", [0, 255, 255]),
-        ("violet", [255, 0, 255]),
+        ("violet", [224, 128, 255]),
         ("lightgreen", [192, 255, 192]),
         ("brown", [165, 42, 42]),
         ("lightblue", [192, 192, 255]),
@@ -34,7 +34,6 @@ class Grids(problem.Problem):
 
     def __init__(self, device=torch.device("cpu")):
         self.colors = torch.tensor([c for _, c in self.named_colors])
-        self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
         self.height = 10
         self.width = 10
         self.device = device
@@ -66,19 +65,6 @@ class Grids(problem.Problem):
 
         return x
 
-    def frame2img_(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
-        x = self.colors[x].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
-
-        return x
-
     def save_image(
         self,
         result_dir,
@@ -593,38 +579,52 @@ class Grids(problem.Problem):
             X[i, j] = c[1]
             f_X[0:2, 0:2] = c[1]
 
-    def task_islands(self, A, f_A, B, f_B):
+    def task_symbols(self, A, f_A, B, f_B):
+        nb_rec = 4
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        delta = 3
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
-                if (
-                    i == 0
-                    or i == self.height - 1
-                    or j == 0
-                    or j == self.width - 1
-                    or X[i, j] == 1
-                ):
-                    break
-            while True:
-                di, dj = torch.randint(3, (2,)) - 1
-                if abs(di) + abs(dj) > 0:
-                    break
-            X[i, j] = 1
-            while True:
-                i, j = i + di, j + dj
-                if i < 0 or i >= self.height or j < 0 or j >= self.width:
-                    break
-                b = (
-                    i == 0
-                    or i == self.height - 1
-                    or j == 0
-                    or j == self.width - 1
-                    or X[i, j] == 1
+                i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
+                    self.width - delta + 1, (nb_rec,)
                 )
-                X[i, j] = 1
-                if b:
+                d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
+                d.fill_diagonal_(delta + 1)
+                if d.min() > delta:
                     break
 
+            for k in range(1, nb_rec):
+                X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+
+            ai, aj = i.float().mean(), j.float().mean()
+
+            q = torch.randint(3, (1,)) + 1
+
+            X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+            X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+            X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+            X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+            assert i[q] != ai and j[q] != aj
+
+            X[
+                i[0] + delta // 2 + (i[q] - ai).sign().long(),
+                j[0] + delta // 2 + (j[q] - aj).sign().long(),
+            ] = c[nb_rec]
+
+            f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+    def task_islands(self, A, f_A, B, f_B):
+        pass
+
+    # for X, f_X in [(A, f_A), (B, f_B)]:
+    # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
+    # k = torch.randperm(self.height * self.width)
+    # X[...]=-1
+    # for q in k:
+    # i,j=q%self.height,q//self.height
+    # if
+
     ######################################################################
 
     def all_tasks(self):
@@ -639,9 +639,16 @@ class Grids(problem.Problem):
             self.task_trajectory,
             self.task_bounce,
             self.task_scale,
+            self.task_symbols,
             # self.task_islands,
         ]
 
+    def trivial_prompts_and_answers(self, prompts, answers):
+        S = self.height * self.width
+        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
+        f_Bs = answers
+        return (B_s == f_Bs).long().min(dim=-1).values > 0
+
     def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
         if tasks is None:
             tasks = self.all_tasks()
@@ -695,8 +702,8 @@ if __name__ == "__main__":
 
     grids = Grids()
 
-    for t in grids.all_tasks():
-        # for t in [grids.task_islands]:
+    for t in grids.all_tasks():
+    for t in [grids.task_islands]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)