Update.
author François Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 09:10:03 +0000 (11:10 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 09:10:03 +0000 (11:10 +0200)
grids.py
quiz_machine.py

index 296c23a..ebc2b0e 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -342,8 +342,21 @@ class Grids(problem.Problem):
     ):
         quizzes = quizzes.to("cpu")
 
-        if not self.check_structure(quizzes, ("A", "f_A", "B", "f_B")):
-            print(f"**WARNING** {filename} is not in A/f_A/B/f_B order")
+        # Gather the quizzes and whichever optional part tensors were given
+        to_reconfigure = [quizzes]
+        if predicted_parts is not None:
+            to_reconfigure.append(predicted_parts)
+        if correct_parts is not None:
+            to_reconfigure.append(correct_parts)
+
+        # Put them all in A/f_A/B/f_B order; Grids is itself the Problem
+        to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
+
+        quizzes = to_reconfigure.pop(0)
+        if predicted_parts is not None:
+            predicted_parts = to_reconfigure.pop(0)
+        if correct_parts is not None:
+            correct_parts = to_reconfigure.pop(0)
 
         S = self.height * self.width
 
index 9ca84b3..1ff23ed 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -79,13 +79,12 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
-        self.train_struct = [
-            ("A", "f_A", "B", "f_B"),  # The standard order
-            ("f_A", "A", "f_B", "B"),  # The reverse order for validation
-            ("B", "f_B", "A", "f_A"),
-            ("f_B", "B", "f_A", "A"),
-            ("f_B", "f_A", "A", "B"),  # The synthesis order
-            ("f_B", "f_A", "A", "B"),  # twice!
+        self.understood_structures = [
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
         ]
 
         self.LOCK_C_QUIZZES = threading.Lock()
@@ -172,7 +171,9 @@ class QuizMachine:
                 quizzes = w_quizzes.clone()
                 from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
 
-            self.randomize_configuations_inplace(quizzes, structs=self.train_struct)
+            # understood_structures holds (struct, mask) pairs; keep the structs
+            train_structs = [s for s, _ in self.understood_structures]
+            self.randomize_configuations_inplace(quizzes, structs=train_structs)
 
             i = torch.randperm(quizzes.size(0), device=quizzes.device)
 
@@ -181,7 +182,7 @@ class QuizMachine:
     ######################################################################
 
     def make_ar_mask(self, quizzes, struct, mask):
-        assert struct in self.train_struct
+        assert struct in [s for s, _ in self.understood_structures]  # structs only
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
@@ -215,13 +216,7 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask in [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
-            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
-        ]:
+        for struct, mask in self.understood_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i] = self.predict(
@@ -249,10 +244,6 @@ class QuizMachine:
         predicted_parts = predicted_parts[:128]
         correct_parts = correct_parts[:128]
 
-        result, predicted_parts, correct_parts = self.problem.reconfigure(
-            [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B")
-        )
-
         self.problem.save_quizzes_as_image(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
@@ -305,7 +296,7 @@ class QuizMachine:
         )
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, structs=self.train_struct
+            model.train_w_quizzes, structs=[s for s, _ in self.understood_structures]
         )
 
     ######################################################################