Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 14:40:18 +0000 (17:40 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 14:40:18 +0000 (17:40 +0300)
lang.py

diff --git a/lang.py b/lang.py
index d53386c..3d939bb 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -73,6 +73,7 @@ class Lang(problem.Problem):
             predicted_answers = 255
 
         def add_frame(x, c, margin, bottom=False):
+            print(f"{type(x)=} {type(c)=}")
             if bottom:
                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
             else:
@@ -89,7 +90,7 @@ class Lang(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([0, 0, 0], device=c.device) + (
+                c = c * torch.tensor([192, 192, 192], device=c.device) + (
                     1 - c
                 ) * torch.tensor([255, 255, 255], device=c.device)
                 y[...] = c[:, :, None, None]
@@ -98,44 +99,66 @@ class Lang(problem.Problem):
 
             return y
 
-        margin = 4
+        margin = 8
 
-        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
-        h = img_prompts.size(2)
-        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
-        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
-        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
-        img_prompts = add_frame(
-            img_prompts, c=predicted_prompts, margin=margin, bottom=True
+        img_prompts = torch.cat(
+            [
+                add_frame(
+                    add_frame(self.frame2img(x), c=0, margin=1),
+                    c=predicted_prompts,
+                    margin=margin,
+                )
+                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+            ],
+            dim=3,
         )
+
+        h = img_prompts.size(2)
         img_answers = add_frame(
-            img_answers, c=predicted_answers, margin=margin, bottom=True
+            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
+            c=predicted_answers,
+            margin=margin,
         )
 
-        marker_size = 16
+        separator_size = 2 * margin
 
         separator = img_prompts.new_full(
             (
                 img_prompts.size(0),
                 img_prompts.size(1),
                 img_prompts.size(2),
-                marker_size,
+                separator_size,
             ),
             255,
         )
 
-        separator[:, :, 0] = 0
-        separator[:, :, h - 1] = 0
-
-        for k in range(1, 2 * marker_size - 8):
-            i = k - (marker_size - 4)
-            j = marker_size - 5 - abs(i)
-            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
-            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        marker = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                separator_size,
+            ),
+            255,
+        )
 
-        img = torch.cat([img_prompts, separator, img_answers], dim=3)
+        # marker[:, :, 0] = 0
+        # marker[:, :, h - 1] = 0
+
+        for k in range(1, 2 * separator_size - 8):
+            i = k - (separator_size - 4)
+            j = separator_size - 5 - abs(i)
+            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
+            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat(
+            [
+                img_prompts,
+                marker,
+                img_answers,
+            ],
+            dim=3,
+        )
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
@@ -207,11 +230,11 @@ if __name__ == "__main__":
 
     prompts, answers = lang.generate_prompts_and_answers(24)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    # predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+    predicted_answers = torch.logical_not(predicted_prompts)
 
     lang.save_quizzes(
-        "/tmp", "test", prompts, answers  # , predicted_prompts, predicted_answers
+        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
     )
 
     # start_time = time.perf_counter()