Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 18:37:01 +0000 (20:37 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 18:37:01 +0000 (20:37 +0200)
grids.py
quiz_machine.py

index 5ddcf32..80b8b1d 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -139,18 +139,24 @@ class Grids(problem.Problem):
 
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
+        result = quizzes.new(quizzes.size())
+
+        struct_from = self.get_structure(quizzes[:1])
+        i = self.indices_select(quizzes, struct_from)
 
-        struct_from = self.get_structure(quizzes)
         sf = dict((l, n) for n, l in enumerate(struct_from))
 
-        result = quizzes.new(quizzes.size())
-        q = quizzes.reshape(quizzes.size(0), 4, S + 1)
-        r = result.reshape(result.size(0), 4, S + 1)
+        q = quizzes.reshape(-1, 4, S + 1)[i]
+
+        result[i, 0 * (S + 1) : 1 * (S + 1)] = q[:, sf[struct[0]], :]
+        result[i, 1 * (S + 1) : 2 * (S + 1)] = q[:, sf[struct[1]], :]
+        result[i, 2 * (S + 1) : 3 * (S + 1)] = q[:, sf[struct[2]], :]
+        result[i, 3 * (S + 1) : 4 * (S + 1)] = q[:, sf[struct[3]], :]
+
+        j = i == False
 
-        r[:, 0] = q[:, sf[struct[0]], :]
-        r[:, 1] = q[:, sf[struct[1]], :]
-        r[:, 2] = q[:, sf[struct[2]], :]
-        r[:, 3] = q[:, sf[struct[3]], :]
+        if j.any():
+            result[j] = self.reconfigure(quizzes[j], struct=struct)
 
         return result
 
@@ -258,8 +264,8 @@ class Grids(problem.Problem):
         y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
         y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
-        y[:, :, torch.arange(0, y.size(2), scale), :] = 0
+        y[:, :, :, torch.arange(0, y.size(3), scale)] = 224
+        y[:, :, torch.arange(0, y.size(2), scale), :] = 224
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
@@ -1446,9 +1452,12 @@ if __name__ == "__main__":
 
     nb = 5
     quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
+    print(quizzes)
     print(grids.get_structure(quizzes))
     quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+    print("DEBUG2", quizzes)
     print(grids.get_structure(quizzes))
+    print(quizzes)
 
     i = torch.rand(quizzes.size(0)) < 0.5
 
index 2fb196c..a384377 100755 (executable)
@@ -131,6 +131,12 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
+        self.configurations = [
+            ("A", "f_A", "B", "f_B"),  # The standard order
+            ("f_A", "A", "f_B", "B"),  # The reverse order for validation
+            ("f_B", "f_A", "A", "B"),  # The synthesis order
+        ]
+
         self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []
@@ -212,7 +218,8 @@ class QuizMachine:
         nb = 0
         for struct, mask in [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
-            (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
         ]:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
@@ -220,6 +227,7 @@ class QuizMachine:
                 model=model, quizzes=input[i], struct=struct, mask=mask
             )
 
+        print(f"{nb=} {input.size(0)=}")
         assert nb == input.size(0)
 
         main_test_accuracy = correct.sum() / correct.size(0)
@@ -236,27 +244,29 @@ class QuizMachine:
 
     ######################################################################
 
-    def flip_half_in_place(self, quizzes):
-        r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5
-        i = self.problem.indices_select(
-            quizzes=quizzes, struct=("A", "f_A", "B", "f_B")
-        )
-        quizzes[i & r] = self.problem.reconfigure(
-            quizzes[i & r], struct=("f_B", "f_A", "B", "A")
-        )
-        j = self.problem.indices_select(
-            quizzes=quizzes, struct=("f_B", "f_A", "B", "A")
-        )
-        quizzes[j & r] = self.problem.reconfigure(
-            quizzes[j & r], struct=("A", "f_A", "B", "f_B")
+    def randomize_configuations_inplace(self, quizzes, configurations):
+        r = torch.randint(
+            len(configurations), (quizzes.size(0),), device=quizzes.device
         )
 
+        for c in range(len(configurations)):
+            quizzes[r == c] = self.problem.reconfigure(
+                quizzes[r == c], struct=configurations[c]
+            )
+
     def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
         model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
         model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
 
-        self.flip_half_in_place(model.train_w_quizzes)
-        self.flip_half_in_place(model.test_w_quizzes)
+        self.randomize_configuations_inplace(
+            model.train_w_quizzes, configurations=self.configurations
+        )
+
+        self.randomize_configuations_inplace(
+            model.test_w_quizzes, configurations=self.configurations
+        )
+
+        # print(model.train_w_quizzes.sum())
 
     ######################################################################
 
@@ -287,7 +297,9 @@ class QuizMachine:
                 model.train_w_quizzes.size(0)
             )
 
-        self.flip_half_in_place(model.train_w_quizzes)
+        self.randomize_configuations_inplace(
+            model.train_w_quizzes, configurations=self.configurations
+        )
 
     ######################################################################