Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 12:04:08 +0000 (14:04 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 12:04:08 +0000 (14:04 +0200)
grids.py

index 131f85c..eaba99a 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -122,12 +122,20 @@ class Grids(problem.Problem):
         S = self.height * self.width
 
         return (
-            (quizzes[:, 0 * (S + 1)] == self.l2tok(struct[0]))
-            & (quizzes[:, 1 * (S + 1)] == self.l2tok(struct[1]))
-            & (quizzes[:, 2 * (S + 1)] == self.l2tok(struct[2]))
-            & (quizzes[:, 3 * (S + 1)] == self.l2tok(struct[3]))
+            (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
+            & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
+            & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
+            & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
         ).all()
 
+    def get_structure(self, quizzes):
+        S = self.height * self.width
+        struct = tuple(
+            self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0]
+        )
+        self.check_structure(quizzes, struct)
+        return struct
+
     def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
         assert check_structure(quizzes, struct)
 
@@ -141,24 +149,20 @@ class Grids(problem.Problem):
 
         return ar_mask
 
-    def reconfigure(
-        self,
-        quizzes,
-        struct_from=("A", "f_A", "B", "f_B"),
-        struct_to=("f_B", "A", "f_A", "B"),
-    ):
-        assert check_structure(quizzes, struct_from)
+    def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
 
+        struct_from = self.get_structure(quizzes)
         sf = dict((l, n) for n, l in enumerate(struct_from))
 
         result = quizzes.new(quizzes.size())
-        q = quizzes.reshape(-1, 4, 4 * (S + 1))
-        r = reshape.reshape(-1, 4, 4 * (S + 1))
+        q = quizzes.reshape(-1, 4, S + 1)
+        r = result.reshape(-1, 4, S + 1)
 
-        r[:, 0, :] = q[:, sf[struct_to[0]]]
-        r[:, 1, :] = q[:, sf[struct_to[1]]]
-        r[:, 2, :] = q[:, sf[struct_to[2]]]
-        r[:, 3, :] = q[:, sf[struct_to[3]]]
+        r[:, 0] = q[:, sf[struct[0]], :]
+        r[:, 1] = q[:, sf[struct[1]], :]
+        r[:, 2] = q[:, sf[struct[2]], :]
+        r[:, 3] = q[:, sf[struct[3]], :]
 
         return result
 
@@ -175,6 +179,7 @@ class Grids(problem.Problem):
         self.token_f_A = self.token_A + 1
         self.token_B = self.token_f_A + 1
         self.token_f_B = self.token_B + 1
+
         self.l2tok = {
             "A": self.token_A,
             "f_A": self.token_f_A,
@@ -182,6 +187,13 @@ class Grids(problem.Problem):
             "f_B": self.token_f_B,
         }
 
+        self.tok2l = {
+            self.token_A: "A",
+            self.token_f_A: "f_A",
+            self.token_B: "B",
+            self.token_f_B: "f_B",
+        }
+
         self.nb_token_values = self.token_f_B + 1
 
         self.height = 10
@@ -302,10 +314,10 @@ class Grids(problem.Problem):
                     + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
 
-        img_A = self.add_frame(img_A, colors[:, 0], thickness=6)
-        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6)
-        img_B = self.add_frame(img_B, colors[:, 2], thickness=6)
-        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6)
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
 
         img_A = self.add_frame(img_A, white[None, :], thickness=2)
         img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
@@ -1377,6 +1389,12 @@ class Grids(problem.Problem):
                 total=prompts.size(0),
             )
 
+        quizzes[...] = 0
+        quizzes[:, 0 * (S + 1)] = self.token_A
+        quizzes[:, 1 * (S + 1)] = self.token_f_A
+        quizzes[:, 2 * (S + 1)] = self.token_B
+        quizzes[:, 3 * (S + 1)] = self.token_f_B
+
         for quiz in quizzes:
             q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
             q[...] = 0
@@ -1404,6 +1422,13 @@ if __name__ == "__main__":
     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
     grids = Grids()
 
+    nb = 5
+    quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
+    print(grids.get_structure(quizzes))
+    blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+    print(grids.get_structure(blah))
+    exit(0)
+
     # nb = 1000
     # grids = problem.MultiThreadProblem(
     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
@@ -1424,9 +1449,14 @@ if __name__ == "__main__":
         # for t in [grids.task_symbols]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        print(grids.get_structure(quizzes))
         predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
-        predicted_parts[:, 3] = 1
+        predicted_parts[:, 3] = torch.randint(
+            2, (quizzes.size(0),), device=quizzes.device
+        )
+        predicted_parts[:, :3] = 1 - predicted_parts[:, 3:]
         correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
+        correct_parts[:, 1:2] = correct_parts[:, :1]
         grids.save_quizzes_as_image(
             "/tmp",
             t.__name__ + ".png",