Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 12:20:10 +0000 (15:20 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 12:20:10 +0000 (15:20 +0300)
quizz_machine.py

index 239dc68..bf36d0b 100755 (executable)
@@ -311,7 +311,6 @@ class QuizzMachine:
             self.test_c_quizzes.append(new_c_quizzes)
 
     def comput_correctness(self, c_quizzes, models_for_validation):
-        ###############################################################
         # Create the reverse quizzes
 
         token_forward, token_backward = self.problem.direction_tokens()
@@ -328,11 +327,9 @@ class QuizzMachine:
         ar_mask = self.make_ar_mask(c_quizzes)
         seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
 
-        ###############################################################
-        # Check how many of the other models can solve them in both
-        # directions
+        # Check how many of models can solve the quizzes in both directions
 
-        nb_correct = []
+        nb_correct = 0
 
         for model in models_for_validation:
             result = c_quizzes.clone()
@@ -369,14 +366,13 @@ class QuizzMachine:
                 (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
             )
 
-            nb_correct.append((correct * reverse_correct)[None, :])
+            nb_correct += correct * reverse_correct
 
-        return torch.cat(nb_correct, dim=0).sum(dim=0)
+        return nb_correct
 
-    def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
-        ###############################################################
-        # Generate quizzes with model
+    ###############################################################
 
+    def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )