Update.
[culture.git] / quizz_machine.py
index d591d79..198d279 100755 (executable)
@@ -286,7 +286,9 @@ class QuizzMachine:
 
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
-    def compute_correctness(self, c_quizzes, models_for_validation):
+    def compute_correctness(
+        self, c_quizzes, models_for_validation, both_direction=True
+    ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)
@@ -313,25 +315,30 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reversed_result = reversed_c_quizzes.clone()
+            if both_direction:
+                reversed_result = reversed_c_quizzes.clone()
 
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=reversed_result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                # progress_bar_desc="solving reversed c_quizzes",
-                device=self.device,
-            )
+                masked_inplace_autoregression(
+                    model=model,
+                    batch_size=self.batch_size,
+                    input=reversed_result,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba,
+                    temperature=1.0,
+                    deterministic_synthesis=True,
+                    # progress_bar_desc="solving reversed c_quizzes",
+                    device=self.device,
+                )
 
-            reversed_correct = (
-                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
-            )
+                reversed_correct = (
+                    (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+                )
+
+                correct *= reversed_correct
+
+            # endif
 
-            nb_correct += correct * reversed_correct
+            nb_correct += correct
 
         return nb_correct