Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 12:57:30 +0000 (15:57 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 12:57:30 +0000 (15:57 +0300)
lang.py [new file with mode: 0755]

diff --git a/lang.py b/lang.py
new file mode 100755 (executable)
index 0000000..d53386c
--- /dev/null
+++ b/lang.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm, os, warnings
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+import problem
+
+
+class Lang(problem.Problem):
+    named_colors = [
+        ("white", [255, 255, 255]),
+        ("red", [255, 0, 0]),
+        ("green", [0, 192, 0]),
+        ("blue", [0, 0, 255]),
+        ("orange", [255, 192, 0]),
+        ("cyan", [0, 255, 255]),
+        ("violet", [255, 0, 255]),
+        ("lightgreen", [192, 255, 192]),
+        ("pink", [255, 192, 192]),
+        ("lightblue", [192, 192, 255]),
+        ("gray", [192, 192, 192]),
+    ]
+
+    def __init__(
+        self,
+        nb_iterations=2,
+    ):
+        self.colors = torch.tensor([c for _, c in self.named_colors])
+        self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
+        self.height = 10
+        self.width = 10
+        self.nb_iterations = nb_iterations
+
+    ######################################################################
+
+    def frame2img(self, x, scale=15):
+        x = x.reshape(x.size(0), self.height, -1)
+        x = self.colors[x].permute(0, 3, 1, 2)
+        s = x.shape
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
+        x = x[:, :, 1:, 1:]
+
+        return x
+
+    def save_image(
+        self,
+        result_dir,
+        filename,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        if predicted_prompts is None:
+            predicted_prompts = 255
+
+        if predicted_answers is None:
+            predicted_answers = 255
+
+        def add_frame(x, c, margin, bottom=False):
+            if bottom:
+                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+            else:
+                h, w, di, dj = (
+                    x.size(2) + 2 * margin,
+                    x.size(3) + 2 * margin,
+                    margin,
+                    margin,
+                )
+
+            y = x.new_full((x.size(0), x.size(1), h, w), 0)
+
+            if type(c) is int:
+                y[...] = c
+            else:
+                c = c.long()[:, None]
+                c = c * torch.tensor([0, 0, 0], device=c.device) + (
+                    1 - c
+                ) * torch.tensor([255, 255, 255], device=c.device)
+                y[...] = c[:, :, None, None]
+
+            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
+            return y
+
+        margin = 4
+
+        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+        h = img_prompts.size(2)
+        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
+
+        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
+        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+
+        img_prompts = add_frame(
+            img_prompts, c=predicted_prompts, margin=margin, bottom=True
+        )
+        img_answers = add_frame(
+            img_answers, c=predicted_answers, margin=margin, bottom=True
+        )
+
+        marker_size = 16
+
+        separator = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                marker_size,
+            ),
+            255,
+        )
+
+        separator[:, :, 0] = 0
+        separator[:, :, h - 1] = 0
+
+        for k in range(1, 2 * marker_size - 8):
+            i = k - (marker_size - 4)
+            j = marker_size - 5 - abs(i)
+            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat([img_prompts, separator, img_answers], dim=3)
+
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(
+            img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
+        )
+
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def rec_coo(self, x):
+        while True:
+            i1, i2 = torch.randint(x.size(0), (2,))
+            if i1 < i2 - 1:
+                break
+        while True:
+            j1, j2 = torch.randint(x.size(1), (2,))
+            if j1 < j2 - 1:
+                break
+        return i1, j1, i2, j2
+
+    def task_red_to_green(self, A, f_A, B, f_B):
+        i1, j1, i2, j2 = self.rec_coo(A)
+        A[i1:i2, j1:j2] = self.name2color["red"]
+        f_A[i1:i2, j1:j2] = self.name2color["green"]
+        i1, j1, i2, j2 = self.rec_coo(B)
+        B[i1:i2, j1:j2] = self.name2color["red"]
+        f_B[i1:i2, j1:j2] = self.name2color["green"]
+
+    def generate_prompts_and_answers(self, nb):
+        prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
+        answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
+        w = self.width
+        for prompt, answer in zip(prompts, answers):
+            self.task_red_to_green(
+                prompt[:, 0 * w : 1 * w],
+                prompt[:, 1 * w : 2 * w],
+                prompt[:, 2 * w : 3 * w],
+                answer,
+            )
+        return prompts, answers
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        self.save_image(
+            result_dir,
+            filename_prefix + ".png",
+            prompts,
+            answers,
+            predicted_prompts,
+            predicted_answers,
+        )
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import time
+
+    lang = Lang(nb_iterations=4)
+
+    prompts, answers = lang.generate_prompts_and_answers(24)
+
+    # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+    # predicted_answers = torch.rand(answers.size(0)) < 0.5
+
+    lang.save_quizzes(
+        "/tmp", "test", prompts, answers  # , predicted_prompts, predicted_answers
+    )
+
+    # start_time = time.perf_counter()
+    # token_sequences = lang.generate_token_sequences(nb=64)
+    # delay = time.perf_counter() - start_time
+    # print(f"{token_sequences.size(0)/delay:02f} seq/s")
+
+    # print(lang.seq2str(seq[:4]))
+
+    # for t in range(len(it[0])):
+    # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0)
+    # torchvision.utils.save_image(
+    # img.float() / 255.0,
+    # f"/tmp/frame_{t:03d}.png",
+    # nrow=8,
+    # padding=6,
+    # pad_value=0,
+    # )
+
+    # m = (torch.rand(seq.size()) < 0.05).long()
+    # seq = (1 - m) * seq + m * 23
+
+    # print(seq.size())
+    # img = lang.seq2img(token_sequences)
+    # print(img.size())
+
+    # torchvision.utils.save_image(
+    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
+    # )