Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 15:51:04 +0000 (18:51 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 15:51:04 +0000 (18:51 +0300)
lang.py

diff --git a/lang.py b/lang.py
index 5adf50f..4c3c64f 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -183,18 +183,20 @@ class Lang(problem.Problem):
                 break
         return i1, j1, i2, j2
 
+    ######################################################################
+
     def task_replace_color(self, A, f_A, B, f_B):
         c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
-        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
             for _ in range(torch.randint(n, (1,)) + 1):
                 i1, j1, i2, j2 = self.rec_coo(X)
                 X[i1:i2, j1:j2] = c1
                 f_X[i1:i2, j1:j2] = c2
 
     def task_move(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:1] + 1
         di, dj = torch.randint(2, (2,)) * 2 - 1
-        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+            c = torch.randperm(len(self.colors) - 1)[:1] + 1
             for _ in range(torch.randint(n, (1,)) + 1):
                 while True:
                     i1, j1, i2, j2 = self.rec_coo(X)
@@ -210,20 +212,44 @@ class Lang(problem.Problem):
                 f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
 
     def task_grow(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:1] + 1
-
-        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+        direction = torch.randint(2, (1,))
+        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+            c = torch.randperm(len(self.colors) - 1)[:1] + 1
             for _ in range(torch.randint(n, (1,)) + 1):
                 while True:
                     i1, j1, i2, j2 = self.rec_coo(X)
                     if i1 + 3 < i2 and j1 + 3 < j2:
                         break
 
-                X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+                if direction == 0:
+                    X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+                    f_X[i1:i2, j1:j2] = c
+                else:
+                    X[i1:i2, j1:j2] = c
+                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+
+    def task_frame(self, A, f_A, B, f_B):
+        direction = torch.randint(2, (1,))
+        for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+            c = torch.randperm(len(self.colors) - 1)[:1] + 1
+            for _ in range(torch.randint(n, (1,)) + 1):
+                while True:
+                    i1, j1, i2, j2 = self.rec_coo(X)
+                    if i1 + 3 < i2 and j1 + 3 < j2:
+                        break
+                X[i1:i2, j1:j2] = c
                 f_X[i1:i2, j1:j2] = c
+                f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+
+    ######################################################################
 
     def generate_prompts_and_answers(self, nb):
-        tasks = [self.task_replace_color, self.task_move, self.task_grow]
+        tasks = [
+            self.task_replace_color,
+            self.task_move,
+            self.task_grow,
+            self.task_frame,
+        ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
@@ -263,40 +289,14 @@ if __name__ == "__main__":
 
     prompts, answers = lang.generate_prompts_and_answers(36)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.logical_not(predicted_prompts)
+    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+    predicted_answers = torch.logical_not(predicted_prompts)
 
     lang.save_quizzes(
         "/tmp",
         "test",
         prompts,
-        answers,  # predicted_prompts, predicted_answers
+        answers,
+        # You can add a bool to put a frame around the predicted parts
+        # predicted_prompts, predicted_answers
     )
-
-    # start_time = time.perf_counter()
-    # token_sequences = lang.generate_token_sequences(nb=64)
-    # delay = time.perf_counter() - start_time
-    # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
-    # print(lang.seq2str(seq[:4]))
-
-    # for t in range(len(it[0])):
-    # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0)
-    # torchvision.utils.save_image(
-    # img.float() / 255.0,
-    # f"/tmp/frame_{t:03d}.png",
-    # nrow=8,
-    # padding=6,
-    # pad_value=0,
-    # )
-
-    # m = (torch.rand(seq.size()) < 0.05).long()
-    # seq = (1 - m) * seq + m * 23
-
-    # print(seq.size())
-    # img = lang.seq2img(token_sequences)
-    # print(img.size())
-
-    # torchvision.utils.save_image(
-    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
-    # )