sum_logits += c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += c_quizzes.size(0)
 
-        nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
+        nb_correct = quizz_machine.compute_correctness(
+            c_quizzes, models, both_direction=True
+        )
 
         if args.dirty_debug:
             nb_correct = torch.randint(
 
 
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
-    def compute_correctness(self, c_quizzes, models_for_validation):
+    def compute_correctness(
+        self, c_quizzes, models_for_validation, both_direction=True
+    ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reversed_result = reversed_c_quizzes.clone()
+            if both_direction:
+                reversed_result = reversed_c_quizzes.clone()
 
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=reversed_result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                # progress_bar_desc="solving reversed c_quizzes",
-                device=self.device,
-            )
+                masked_inplace_autoregression(
+                    model=model,
+                    batch_size=self.batch_size,
+                    input=reversed_result,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba,
+                    temperature=1.0,
+                    deterministic_synthesis=True,
+                    # progress_bar_desc="solving reversed c_quizzes",
+                    device=self.device,
+                )
 
-            reversed_correct = (
-                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
-            )
+                reversed_correct = (
+                    (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+                )
+
+                correct *= reversed_correct
+
+            # endif
 
-            nb_correct += correct * reversed_correct
+            nb_correct += correct
 
         return nb_correct