Update.
[culture.git] / reasoning.py
index 2874adc..9462f87 100755 (executable)
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 import problem
 
 
-class Reasoning(problem.Problem):
+class Grids(problem.Problem):
     named_colors = [
         ("white", [255, 255, 255]),
         ("red", [255, 0, 0]),
@@ -87,9 +87,16 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
-        prompts = prompts.reshape(prompts.size(0), self.height, -1)
-        answers = answers.reshape(answers.size(0), self.height, -1)
+        S = self.height * self.width
+        As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
+        f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
+            -1, self.height, self.width
+        )
+        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
+        prompts = torch.cat([As, f_As, Bs], dim=2)
+        answers = answers.reshape(answers.size(0), self.height, self.width)
 
         if predicted_prompts is None:
             predicted_prompts = 255
@@ -114,9 +121,13 @@ class Reasoning(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([192, 192, 192], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
+                    * torch.tensor([64, 64, 64], device=c.device)
+                    + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -186,7 +197,11 @@ class Reasoning(problem.Problem):
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
+            img.float() / 255.0,
+            image_name,
+            nrow=nrow,
+            padding=margin * 4,
+            pad_value=1.0,
         )
 
     ######################################################################
@@ -293,7 +308,7 @@ class Reasoning(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
-    def task_move(self, A, f_A, B, f_B):
+    def task_translate(self, A, f_A, B, f_B):
         di, dj = torch.randint(3, (2,)) - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
@@ -406,20 +421,234 @@ class Reasoning(problem.Problem):
                 if n < nb_rec - 1:
                     f_X[i1, j1] = c[-1]
 
+    def contact(self, X, i, j, q):
+        nq, nq_diag = 0, 0
+        no = 0
+
+        for ii, jj in [
+            (i - 1, j - 1),
+            (i - 1, j),
+            (i - 1, j + 1),
+            (i, j - 1),
+            (i, j + 1),
+            (i + 1, j - 1),
+            (i + 1, j),
+            (i + 1, j + 1),
+        ]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] != 0 and X[ii, jj] != q:
+                    no += 1
+
+        for ii, jj in [
+            (i - 1, j - 1),
+            (i - 1, j + 1),
+            (i + 1, j - 1),
+            (i + 1, j + 1),
+        ]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
+                    nq_diag += 1
+
+        for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] == q:
+                    nq += 1
+
+        return no, nq, nq_diag
+
+    def task_count(self, A, f_A, B, f_B):
+        N = torch.randint(4, (1,)) + 2
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb = torch.zeros(N, dtype=torch.int64)
+            q = torch.randint(N, (self.height * self.width,))
+            k = torch.randperm(self.height * self.width)
+            for p in range(self.height * self.width):
+                i, j = k[p] % self.height, k[p] // self.height
+                no, nq, nq_diag = self.contact(X, i, j, c[q[p]])
+                if no == 0 and nq_diag == 0:
+                    if nq == 0:
+                        if nb[q[p]] < self.width:
+                            X[i, j] = c[q[p]]
+                            nb[q[p]] += 1
+                    if nq == 1:
+                        X[i, j] = c[q[p]]
+
+            for n in range(N):
+                for j in range(nb[n]):
+                    f_X[n, j] = c[n]
+
+    def task_trajectory(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                di, dj = torch.randint(7, (2,)) - 3
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                if (
+                    abs(di) + abs(dj) > 0
+                    and i + 2 * di >= 0
+                    and i + 2 * di < self.height
+                    and j + 2 * dj >= 0
+                    and j + 2 * dj < self.width
+                ):
+                    break
+
+            k = 0
+            while (
+                i + k * di >= 0
+                and i + k * di < self.height
+                and j + k * dj >= 0
+                and j + k * dj < self.width
+            ):
+                if k < 2:
+                    X[i + k * di, j + k * dj] = c[k]
+                f_X[i + k * di, j + k * dj] = c[min(k, 1)]
+                k += 1
+
+    def task_bounce(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+
+            def free(i, j):
+                return (
+                    i >= 0
+                    and i < self.height
+                    and j >= 0
+                    and j < self.width
+                    and f_X[i, j] == 0
+                )
+
+            while True:
+                f_X[...] = 0
+                X[...] = 0
+
+                for _ in range((self.height * self.width) // 10):
+                    i, j = torch.randint(self.height, (1,)), torch.randint(
+                        self.width, (1,)
+                    )
+                    X[i, j] = c[0]
+                    f_X[i, j] = c[0]
+
+                while True:
+                    di, dj = torch.randint(7, (2,)) - 3
+                    if abs(di) + abs(dj) == 1:
+                        break
+
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+
+                X[i, j] = c[1]
+                f_X[i, j] = c[1]
+                l = 0
+
+                while True:
+                    l += 1
+                    if free(i + di, j + dj):
+                        pass
+                    elif free(i - dj, j + di):
+                        di, dj = -dj, di
+                        if free(i + dj, j - di):
+                            if torch.rand(1) < 0.5:
+                                di, dj = -di, -dj
+                    elif free(i + dj, j - di):
+                        di, dj = dj, -di
+                    else:
+                        break
+
+                    i, j = i + di, j + dj
+                    f_X[i, j] = c[2]
+                    if l <= 1:
+                        X[i, j] = c[2]
+
+                    if l >= self.width:
+                        break
+
+                f_X[i, j] = c[1]
+                X[i, j] = c[1]
+
+                if l > 3:
+                    break
+
+    def task_scale(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+
+        i, j = torch.randint(self.height // 2, (1,)), torch.randint(
+            self.width // 2, (1,)
+        )
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            for _ in range(3):
+                while True:
+                    i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
+                        break
+                X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
+                f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
+
+            X[i, j] = c[1]
+            f_X[0:2, 0:2] = c[1]
+
+    def task_islands(self, A, f_A, B, f_B):
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                if (
+                    i == 0
+                    or i == self.height - 1
+                    or j == 0
+                    or j == self.width - 1
+                    or X[i, j] == 1
+                ):
+                    break
+            while True:
+                di, dj = torch.randint(3, (2,)) - 1
+                if abs(di) + abs(dj) > 0:
+                    break
+            X[i, j] = 1
+            while True:
+                i, j = i + di, j + dj
+                if i < 0 or i >= self.height or j < 0 or j >= self.width:
+                    break
+                b = (
+                    i == 0
+                    or i == self.height - 1
+                    or j == 0
+                    or j == self.width - 1
+                    or X[i, j] == 1
+                )
+                X[i, j] = 1
+                if b:
+                    break
+
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb, device="cpu"):
-        tasks = [
+    def all_tasks(self):
+        return [
             self.task_replace_color,
-            self.task_move,
+            self.task_translate,
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
             self.task_detect,
+            self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
+            self.task_scale,
+            # self.task_islands,
         ]
-        prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
-        answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
-        w = self.width
+
+    def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
+        if tasks is None:
+            tasks = self.all_tasks()
+
+        S = self.height * self.width
+        prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
+        answers = torch.zeros(nb, S, dtype=torch.int64)
 
         for prompt, answer in tqdm.tqdm(
             zip(prompts, answers),
@@ -427,10 +656,10 @@ class Reasoning(problem.Problem):
             desc="world generation",
             total=prompts.size(0),
         ):
-            A = prompt[:, 0 * w : 1 * w]
-            f_A = prompt[:, 1 * w : 2 * w]
-            B = prompt[:, 2 * w : 3 * w]
-            f_B = answer
+            A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
+            f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
+            B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
+            f_B = answer.view(self.height, self.width)
             task = tasks[torch.randint(len(tasks), (1,))]
             task(A, f_A, B, f_B)
 
@@ -444,6 +673,7 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
         self.save_image(
             result_dir,
@@ -452,6 +682,7 @@ class Reasoning(problem.Problem):
             answers,
             predicted_prompts,
             predicted_answers,
+            nrow,
         )
 
 
@@ -460,22 +691,35 @@ class Reasoning(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    reasoning = Reasoning()
+    nb = 48
+
+    grids = Grids()
+
+    for t in grids.all_tasks():
+        # for t in [grids.task_islands]:
+        print(t.__name__)
+        prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
+        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+
+    exit(0)
+
+    nb = 72
 
     start_time = time.perf_counter()
-    prompts, answers = reasoning.generate_prompts_and_answers(100)
+    prompts, answers = grids.generate_prompts_and_answers(nb)
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.logical_not(predicted_prompts)
+    m = torch.randint(2, (prompts.size(0),))
+    predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
+    predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    reasoning.save_quizzes(
+    grids.save_quizzes(
         "/tmp",
         "test",
-        prompts[:64],
-        answers[:64],
+        prompts[:nb],
+        answers[:nb],
         # You can add a bool to put a frame around the predicted parts
-        predicted_prompts[:64],
-        predicted_answers[:64],
+        predicted_prompts[:nb],
+        predicted_answers[:nb],
     )