Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 14:51:03 +0000 (17:51 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 14:51:03 +0000 (17:51 +0300)
lang.py
main.py
quizz_machine.py

diff --git a/lang.py b/lang.py
index 3d939bb..ce159e2 100755 (executable)
--- a/lang.py
+++ b/lang.py
@@ -66,6 +66,9 @@ class Lang(problem.Problem):
         predicted_prompts=None,
         predicted_answers=None,
     ):
+        prompts = prompts.reshape(prompts.size(0), self.height, -1)
+        answers = answers.reshape(answers.size(0), self.height, -1)
+
         if predicted_prompts is None:
             predicted_prompts = 255
 
@@ -73,7 +76,6 @@ class Lang(problem.Problem):
             predicted_answers = 255
 
         def add_frame(x, c, margin, bottom=False):
-            print(f"{type(x)=} {type(c)=}")
             if bottom:
                 h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
             else:
@@ -181,26 +183,42 @@ class Lang(problem.Problem):
                 break
         return i1, j1, i2, j2
 
-    def task_red_to_green(self, A, f_A, B, f_B):
+    def task_replace_color(self, A, f_A, B, f_B):
+        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
+        i1, j1, i2, j2 = self.rec_coo(A)
+        A[i1:i2, j1:j2] = c1
+        f_A[i1:i2, j1:j2] = c2
+        for _ in range(3):
+            i1, j1, i2, j2 = self.rec_coo(B)
+            B[i1:i2, j1:j2] = c1
+            f_B[i1:i2, j1:j2] = c2
+
+    def move_color(self, A, f_A, B, f_B):
+        c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
+
         i1, j1, i2, j2 = self.rec_coo(A)
-        A[i1:i2, j1:j2] = self.name2color["red"]
-        f_A[i1:i2, j1:j2] = self.name2color["green"]
-        i1, j1, i2, j2 = self.rec_coo(B)
-        B[i1:i2, j1:j2] = self.name2color["red"]
-        f_B[i1:i2, j1:j2] = self.name2color["green"]
+        A[i1:i2, j1:j2] = c1
+        f_A[i1:i2, j1:j2] = c1
+
+        while True:
+            i1, j1, i2, j2 = self.rec_coo(A)
+            if i2 < self.height - 1:
+                break
+        A[i1:i2, j1:j2] = c2
+        f_A[i1:i2, j1:j2] = c2
 
     def generate_prompts_and_answers(self, nb):
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
         for prompt, answer in zip(prompts, answers):
-            self.task_red_to_green(
-                prompt[:, 0 * w : 1 * w],
-                prompt[:, 1 * w : 2 * w],
-                prompt[:, 2 * w : 3 * w],
-                answer,
-            )
-        return prompts, answers
+            A = prompt[:, 0 * w : 1 * w]
+            f_A = prompt[:, 1 * w : 2 * w]
+            B = prompt[:, 2 * w : 3 * w]
+            f_B = answer
+            # self.task_replace_color(A, f_A, B, f_B)
+            self.move_color(A, f_A, B, f_B)
+        return prompts.flatten(1), answers.flatten(1)
 
     def save_quizzes(
         self,
diff --git a/main.py b/main.py
index a8a6191..b4e7318 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, wireworld, quizz_machine
+import sky, lang, quizz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -249,8 +249,8 @@ if args.problem == "sky":
         nb_iterations=args.sky_nb_iterations,
         speed=args.sky_speed,
     )
-elif args.problem == "wireworld":
-    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+elif args.problem == "lang":
+    problem = lang.Lang(nb_iterations=2)
 else:
     raise ValueError
 
index 90f288e..3828e5b 100755 (executable)
@@ -167,6 +167,8 @@ class QuizzMachine:
     def generate_token_sequences(self, nb):
         prompts, answers = self.problem.generate_prompts_and_answers(nb)
 
+        print(f"{prompts.size()=} {answers.size()=}")
+
         if self.prompt_len is None:
             self.prompt_len = prompts.size(1)