Update.
[culture.git] / lang.py
diff --git a/lang.py b/lang.py
index d53386c..ce159e2 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -66,6 +66,9 @@ class Lang(problem.Problem):
         predicted_prompts=None,
         predicted_answers=None,
     ):
+        prompts = prompts.reshape(prompts.size(0), self.height, -1)
+        answers = answers.reshape(answers.size(0), self.height, -1)
+
         if predicted_prompts is None:
             predicted_prompts = 255
 
@@ -89,7 +92,7 @@ class Lang(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([0, 0, 0], device=c.device) + (
+                c = c * torch.tensor([192, 192, 192], device=c.device) + (
                     1 - c
                 ) * torch.tensor([255, 255, 255], device=c.device)
                 y[...] = c[:, :, None, None]
@@ -98,44 +101,66 @@ class Lang(problem.Problem):
 
             return y
 
-        margin = 4
-
-        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
-        h = img_prompts.size(2)
-        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
-        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
-        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+        margin = 8
 
-        img_prompts = add_frame(
-            img_prompts, c=predicted_prompts, margin=margin, bottom=True
+        img_prompts = torch.cat(
+            [
+                add_frame(
+                    add_frame(self.frame2img(x), c=0, margin=1),
+                    c=predicted_prompts,
+                    margin=margin,
+                )
+                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+            ],
+            dim=3,
         )
+
+        h = img_prompts.size(2)
         img_answers = add_frame(
-            img_answers, c=predicted_answers, margin=margin, bottom=True
+            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
+            c=predicted_answers,
+            margin=margin,
         )
 
-        marker_size = 16
+        separator_size = 2 * margin
 
         separator = img_prompts.new_full(
             (
                 img_prompts.size(0),
                 img_prompts.size(1),
                 img_prompts.size(2),
-                marker_size,
+                separator_size,
             ),
             255,
         )
 
-        separator[:, :, 0] = 0
-        separator[:, :, h - 1] = 0
-
-        for k in range(1, 2 * marker_size - 8):
-            i = k - (marker_size - 4)
-            j = marker_size - 5 - abs(i)
-            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
-            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        marker = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                separator_size,
+            ),
+            255,
+        )
 
-        img = torch.cat([img_prompts, separator, img_answers], dim=3)
+        # marker[:, :, 0] = 0
+        # marker[:, :, h - 1] = 0
+
+        for k in range(1, 2 * separator_size - 8):
+            i = k - (separator_size - 4)
+            j = separator_size - 5 - abs(i)
+            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
+            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat(
+            [
+                img_prompts,
+                marker,
+                img_answers,
+            ],
+            dim=3,
+        )
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
@@ -158,26 +183,42 @@ class Lang(problem.Problem):
                 break
         return i1, j1, i2, j2
 
-    def task_red_to_green(self, A, f_A, B, f_B):
+    def task_replace_color(self, A, f_A, B, f_B):
+        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
         i1, j1, i2, j2 = self.rec_coo(A)
-        A[i1:i2, j1:j2] = self.name2color["red"]
-        f_A[i1:i2, j1:j2] = self.name2color["green"]
-        i1, j1, i2, j2 = self.rec_coo(B)
-        B[i1:i2, j1:j2] = self.name2color["red"]
-        f_B[i1:i2, j1:j2] = self.name2color["green"]
+        A[i1:i2, j1:j2] = c1
+        f_A[i1:i2, j1:j2] = c2
+        for _ in range(3):
+            i1, j1, i2, j2 = self.rec_coo(B)
+            B[i1:i2, j1:j2] = c1
+            f_B[i1:i2, j1:j2] = c2
+
+    def move_color(self, A, f_A, B, f_B):
+        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
+
+        i1, j1, i2, j2 = self.rec_coo(A)
+        A[i1:i2, j1:j2] = c1
+        f_A[i1:i2, j1:j2] = c1
+
+        while True:
+            i1, j1, i2, j2 = self.rec_coo(A)
+            if i2 < self.height - 1:
+                break
+        A[i1:i2, j1:j2] = c2
+        f_A[i1:i2, j1:j2] = c2
 
     def generate_prompts_and_answers(self, nb):
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
         for prompt, answer in zip(prompts, answers):
-            self.task_red_to_green(
-                prompt[:, 0 * w : 1 * w],
-                prompt[:, 1 * w : 2 * w],
-                prompt[:, 2 * w : 3 * w],
-                answer,
-            )
-        return prompts, answers
+            A = prompt[:, 0 * w : 1 * w]
+            f_A = prompt[:, 1 * w : 2 * w]
+            B = prompt[:, 2 * w : 3 * w]
+            f_B = answer
+            # self.task_replace_color(A, f_A, B, f_B)
+            self.move_color(A, f_A, B, f_B)
+        return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(
         self,
@@ -207,11 +248,11 @@ if __name__ == "__main__":
 
     prompts, answers = lang.generate_prompts_and_answers(24)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    # predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+    predicted_answers = torch.logical_not(predicted_prompts)
 
     lang.save_quizzes(
-        "/tmp", "test", prompts, answers  # , predicted_prompts, predicted_answers
+        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
     )
 
     # start_time = time.perf_counter()