Update.
author: François Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 04:01:29 +0000 (06:01 +0200)
committer: François Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 04:01:29 +0000 (06:01 +0200)
grids.py
quiz_machine.py

index 25bbc80..f6129e9 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -139,7 +139,7 @@ class Grids(problem.Problem):
 
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         if torch.is_tensor(quizzes):
-            return self.reconfigure([quizzes])[0]
+            return self.reconfigure([quizzes], struct=struct)[0]
 
         S = self.height * self.width
         result = [x.new(x.size()) for x in quizzes]
index 8f14fa0..4615e3a 100755 (executable)
@@ -131,7 +131,7 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
-        self.configurations = [
+        self.train_struct = [
             ("A", "f_A", "B", "f_B"),  # The standard order
             ("f_A", "A", "f_B", "B"),  # The reverse order for validation
             ("f_B", "f_A", "A", "B"),  # The synthesis order
@@ -183,8 +183,12 @@ class QuizMachine:
 
     ######################################################################
 
+    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+        assert struct in self.train_struct
+        return self.problem.make_ar_mask(quizzes, struct, mask)
+
     def predict(self, model, quizzes, struct, mask):
-        ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
+        ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
         seq_logproba = torch.empty(quizzes.size(0), device=self.device)
@@ -250,7 +254,7 @@ class QuizMachine:
         predicted_parts = predicted_parts[:128]
         correct_parts = correct_parts[:128]
 
-        self.problem.reconfigure(
+        result, predicted_parts, correct_parts = self.problem.reconfigure(
             [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B")
         )
 
@@ -281,11 +285,11 @@ class QuizMachine:
         model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, configurations=self.configurations
+            model.train_w_quizzes, configurations=self.train_struct
         )
 
         self.randomize_configuations_inplace(
-            model.test_w_quizzes, configurations=self.configurations
+            model.test_w_quizzes, configurations=self.train_struct
         )
 
     ######################################################################
@@ -318,7 +322,7 @@ class QuizMachine:
             )
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, configurations=self.configurations
+            model.train_w_quizzes, configurations=self.train_struct
         )
 
     ######################################################################
@@ -356,7 +360,7 @@ class QuizMachine:
                     c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
                 ):
                     input = input.to(self.device)
-                    ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
+                    ar_mask = self.make_ar_mask(input, shape="fwd_3_bck_123")
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
@@ -397,7 +401,7 @@ class QuizMachine:
             # A, f(A), B | f(B)
             result = c_quizzes.clone()
 
-            ar_mask = self.problem.make_ar_mask(
+            ar_mask = self.make_ar_mask(
                 result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
             )
 
@@ -418,7 +422,7 @@ class QuizMachine:
             # f(A), A, f(B) | B
             result = reversed_c_quizzes.clone()
 
-            ar_mask = self.problem.make_ar_mask(
+            ar_mask = self.make_ar_mask(
                 result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
             )
 
@@ -462,7 +466,7 @@ class QuizMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(
+            ar_mask=self.make_ar_mask(
                 c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
             ),
             seq_logproba=seq_logproba,
@@ -475,7 +479,7 @@ class QuizMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(
+            ar_mask=self.make_ar_mask(
                 c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
             ),
             seq_logproba=seq_logproba,
@@ -490,7 +494,7 @@ class QuizMachine:
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(
+            ar_mask=self.make_ar_mask(
                 c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
             ),
             seq_logproba=seq_logproba,