Update.
[culture.git] / reasoning.py
index 57e8056..54a4203 100755 (executable)
@@ -27,9 +27,9 @@ class Reasoning(problem.Problem):
         ("cyan", [0, 255, 255]),
         ("violet", [255, 0, 255]),
         ("lightgreen", [192, 255, 192]),
         ("cyan", [0, 255, 255]),
         ("violet", [255, 0, 255]),
         ("lightgreen", [192, 255, 192]),
-        ("pink", [255, 192, 192]),
+        ("brown", [165, 42, 42]),
         ("lightblue", [192, 192, 255]),
         ("lightblue", [192, 192, 255]),
-        ("gray", [192, 192, 192]),
+        ("gray", [128, 128, 128]),
     ]
 
     def __init__(self, device=torch.device("cpu")):
     ]
 
     def __init__(self, device=torch.device("cpu")):
@@ -42,6 +42,31 @@ class Reasoning(problem.Problem):
     ######################################################################
 
     def frame2img(self, x, scale=15):
     ######################################################################
 
     def frame2img(self, x, scale=15):
+        x = x.reshape(x.size(0), self.height, -1)
+        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
+        x = self.colors[x * m].permute(0, 3, 1, 2)
+        s = x.shape
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
+        x = x[:, :, 1:, 1:]
+
+        for n in range(m.size(0)):
+            for i in range(m.size(1)):
+                for j in range(m.size(2)):
+                    if m[n, i, j] == 0:
+                        for k in range(2, scale - 2):
+                            for l in [0, 1]:
+                                x[n, :, i * scale + k, j * scale + k - l] = 0
+                                x[
+                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
+                                ] = 0
+
+        return x
+
+    def frame2img_(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
         x = self.colors[x].permute(0, 3, 1, 2)
         s = x.shape
         x = x.reshape(x.size(0), self.height, -1)
         x = self.colors[x].permute(0, 3, 1, 2)
         s = x.shape
@@ -173,14 +198,13 @@ class Reasoning(problem.Problem):
     # non-overlapping rectangles quickly, but made the generation of
     # 100k samples go from 1h50 with a lame pure python code to 3min30s
     # with this one.
     # non-overlapping rectangles quickly, but made the generation of
     # 100k samples go from 1h50 with a lame pure python code to 3min30s
     # with this one.
-    def rec_coo(self, x, n, min_height=3, min_width=3):
-        K = 3
-        N = 200
+    def rec_coo(self, nb_rec, min_height=3, min_width=3):
+        nb_trials = 200
 
         while True:
             v = (
                 (
 
         while True:
             v = (
                 (
-                    torch.rand(N * K, self.height + 1, device=self.device)
+                    torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
                     .sort(dim=-1)
                     .indices
                     < 2
                     .sort(dim=-1)
                     .indices
                     < 2
@@ -192,7 +216,7 @@ class Reasoning(problem.Problem):
 
             h = (
                 (
 
             h = (
                 (
-                    torch.rand(N * K, self.width + 1, device=self.device)
+                    torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
                     .sort(dim=-1)
                     .indices
                     < 2
                     .sort(dim=-1)
                     .indices
                     < 2
@@ -207,10 +231,10 @@ class Reasoning(problem.Problem):
             )
 
             v, h = v[i], h[i]
             )
 
             v, h = v[i], h[i]
-            v = v[: v.size(0) - v.size(0) % K]
-            h = h[: h.size(0) - h.size(0) % K]
-            v = v.reshape(v.size(0) // K, K, -1)
-            h = h.reshape(h.size(0) // K, K, -1)
+            v = v[: v.size(0) - v.size(0) % nb_rec]
+            h = h[: h.size(0) - h.size(0) % nb_rec]
+            v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
+            h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
 
             r = v[:, :, :, None] * h[:, :, None, :]
 
 
             r = v[:, :, :, None] * h[:, :, None, :]
 
@@ -260,23 +284,23 @@ class Reasoning(problem.Problem):
     ######################################################################
 
     def task_replace_color(self, A, f_A, B, f_B):
     ######################################################################
 
     def task_replace_color(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
-    def task_move(self, A, f_A, B, f_B):
-        di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+    def task_translate(self, A, f_A, B, f_B):
+        di, dj = torch.randint(3, (2,)) - 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(X, N)
-                i1, j1, i2, j2 = r[N - 1]
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
                 if (
                     i1 + di >= 0
                     and i2 + di < X.size(0)
                 if (
                     i1 + di >= 0
                     and i2 + di < X.size(0)
@@ -285,29 +309,29 @@ class Reasoning(problem.Problem):
                 ):
                     break
 
                 ):
                     break
 
-            for n in range(N):
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
                 else:
                     f_X[i1:i2, j1:j2] = c[n]
 
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
                     f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
                 else:
                     f_X[i1:i2, j1:j2] = c[n]
 
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         direction = torch.randint(2, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
         direction = torch.randint(2, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(X, N)
-                i1, j1, i2, j2 = r[N - 1]
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
                 if i1 + 3 < i2 and j1 + 3 < j2:
                     break
 
                 if i1 + 3 < i2 and j1 + 3 < j2:
                     break
 
-            for n in range(N):
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 i1, j1, i2, j2 = r[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     if direction == 0:
                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
                         f_X[i1:i2, j1:j2] = c[n]
                     if direction == 0:
                         X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
                         f_X[i1:i2, j1:j2] = c[n]
@@ -320,12 +344,12 @@ class Reasoning(problem.Problem):
 
     def task_color_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
 
     def task_color_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
         direction = torch.randint(4, (1,))
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[2 * n]
                 f_X[i1:i2, j1:j2] = c[2 * n]
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[2 * n]
                 f_X[i1:i2, j1:j2] = c[2 * n]
@@ -333,64 +357,80 @@ class Reasoning(problem.Problem):
                 if direction == 0:
                     i = (i1 + i2) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
                 if direction == 0:
                     i = (i1 + i2) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i:i2, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 1:
                     i = (i1 + i2 - 1) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
                         f_X[i:i2, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 1:
                     i = (i1 + i2 - 1) // 2
                     X[i : i + 1, j1:j2] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 2:
                     j = (j1 + j2) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
                         f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
                     else:
                         f_X[i : i + 1, j1:j2] = c[2 * n + 1]
                 elif direction == 2:
                     j = (j1 + j2) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1:i2, j:j2] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
                 elif direction == 3:
                     j = (j1 + j2 - 1) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
                         f_X[i1:i2, j:j2] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
                 elif direction == 3:
                     j = (j1 + j2 - 1) // 2
                     X[i1:i2, j : j + 1] = c[2 * n + 1]
-                    if n == N - 1:
+                    if n == nb_rec - 1:
                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
 
     def task_frame(self, A, f_A, B, f_B):
                         f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
 
     def task_frame(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n]
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n]
-                if n == N - 1:
+                if n == nb_rec - 1:
                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
     def task_detect(self, A, f_A, B, f_B):
                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
     def task_detect(self, A, f_A, B, f_B):
-        N = 3
-        c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(X, N)
-            for n in range(N):
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                f_X[i1, j1] = c[-1]
+                if n < nb_rec - 1:
+                    f_X[i1, j1] = c[-1]
+
+    def task_count(self, A, f_A, B, f_B):
+        N = torch.randint(3, (1,)) + 1
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb = torch.randint(self.width, (3,)) + 1
+            k = torch.randperm(self.height * self.width)[: nb.sum()]
+            p = 0
+            for n in range(N):
+                for m in range(nb[n]):
+                    i, j = k[p] % self.height, k[p] // self.height
+                    X[i, j] = c[n]
+                    f_X[n, m] = c[n]
+                    p += 1
 
     ######################################################################
 
     def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
 
     ######################################################################
 
     def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
-            self.task_move,
+            self.task_translate,
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
             self.task_detect,
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
             self.task_detect,
+            self.task_count,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
@@ -442,14 +482,15 @@ if __name__ == "__main__":
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.logical_not(predicted_prompts)
+    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+    predicted_answers = torch.logical_not(predicted_prompts)
 
     reasoning.save_quizzes(
         "/tmp",
         "test",
 
     reasoning.save_quizzes(
         "/tmp",
         "test",
-        prompts[:36],
-        answers[:36],
+        prompts[:64],
+        answers[:64],
         # You can add a bool to put a frame around the predicted parts
         # You can add a bool to put a frame around the predicted parts
-        # predicted_prompts, predicted_answers
+        # predicted_prompts[:64],
+        # predicted_answers[:64],
     )
     )