Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 16:25:54 +0000 (19:25 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 16:25:54 +0000 (19:25 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index eb0ef27..0a7be99 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -417,7 +417,9 @@ def create_c_quizzes(
         sum_logits += c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += c_quizzes.size(0)
 
-        nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
+        nb_correct = quizz_machine.compute_correctness(
+            c_quizzes, models, both_direction=True
+        )
 
         if args.dirty_debug:
             nb_correct = torch.randint(
index d591d79..198d279 100755 (executable)
@@ -286,7 +286,9 @@ class QuizzMachine:
 
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
-    def compute_correctness(self, c_quizzes, models_for_validation):
+    def compute_correctness(
+        self, c_quizzes, models_for_validation, both_direction=True
+    ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)
@@ -313,25 +315,30 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reversed_result = reversed_c_quizzes.clone()
+            if both_direction:
+                reversed_result = reversed_c_quizzes.clone()
 
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=reversed_result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                # progress_bar_desc="solving reversed c_quizzes",
-                device=self.device,
-            )
+                masked_inplace_autoregression(
+                    model=model,
+                    batch_size=self.batch_size,
+                    input=reversed_result,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba,
+                    temperature=1.0,
+                    deterministic_synthesis=True,
+                    # progress_bar_desc="solving reversed c_quizzes",
+                    device=self.device,
+                )
 
-            reversed_correct = (
-                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
-            )
+                reversed_correct = (
+                    (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+                )
+
+                correct *= reversed_correct
+
+            # endif
 
-            nb_correct += correct * reversed_correct
+            nb_correct += correct
 
         return nb_correct