Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 15:06:20 +0000 (18:06 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 15:06:20 +0000 (18:06 +0300)
lang.py

diff --git a/lang.py b/lang.py
index ce159e2..5adf50f 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -185,29 +185,45 @@ class Lang(problem.Problem):
 
     def task_replace_color(self, A, f_A, B, f_B):
         c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
-        i1, j1, i2, j2 = self.rec_coo(A)
-        A[i1:i2, j1:j2] = c1
-        f_A[i1:i2, j1:j2] = c2
-        for _ in range(3):
-            i1, j1, i2, j2 = self.rec_coo(B)
-            B[i1:i2, j1:j2] = c1
-            f_B[i1:i2, j1:j2] = c2
-
-    def move_color(self, A, f_A, B, f_B):
-        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
-
-        i1, j1, i2, j2 = self.rec_coo(A)
-        A[i1:i2, j1:j2] = c1
-        f_A[i1:i2, j1:j2] = c1
-
-        while True:
-            i1, j1, i2, j2 = self.rec_coo(A)
-            if i2 < self.height - 1:
-                break
-        A[i1:i2, j1:j2] = c2
-        f_A[i1:i2, j1:j2] = c2
+        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+            for _ in range(torch.randint(n, (1,)) + 1):
+                i1, j1, i2, j2 = self.rec_coo(X)
+                X[i1:i2, j1:j2] = c1
+                f_X[i1:i2, j1:j2] = c2
+
+    def task_move(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:1] + 1
+        di, dj = torch.randint(2, (2,)) * 2 - 1
+        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+            for _ in range(torch.randint(n, (1,)) + 1):
+                while True:
+                    i1, j1, i2, j2 = self.rec_coo(X)
+                    if (
+                        i1 + di >= 0
+                        and i2 + di < X.size(0)
+                        and j1 + dj >= 0
+                        and j2 + dj < X.size(1)
+                    ):
+                        break
+
+                X[i1:i2, j1:j2] = c
+                f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
+
+    def task_grow(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:1] + 1
+
+        for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+            for _ in range(torch.randint(n, (1,)) + 1):
+                while True:
+                    i1, j1, i2, j2 = self.rec_coo(X)
+                    if i1 + 3 < i2 and j1 + 3 < j2:
+                        break
+
+                X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+                f_X[i1:i2, j1:j2] = c
 
     def generate_prompts_and_answers(self, nb):
+        tasks = [self.task_replace_color, self.task_move, self.task_grow]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
@@ -216,8 +232,7 @@ class Lang(problem.Problem):
             f_A = prompt[:, 1 * w : 2 * w]
             B = prompt[:, 2 * w : 3 * w]
             f_B = answer
-            # self.task_replace_color(A, f_A, B, f_B)
-            self.move_color(A, f_A, B, f_B)
+            tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B)
         return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(
@@ -246,13 +261,16 @@ if __name__ == "__main__":
 
     lang = Lang(nb_iterations=4)
 
-    prompts, answers = lang.generate_prompts_and_answers(24)
+    prompts, answers = lang.generate_prompts_and_answers(36)
 
     predicted_prompts = torch.rand(prompts.size(0)) < 0.5
     predicted_answers = torch.logical_not(predicted_prompts)
 
     lang.save_quizzes(
-        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
+        "/tmp",
+        "test",
+        prompts,
+        answers,  # predicted_prompts, predicted_answers
     )
 
     # start_time = time.perf_counter()