Update.
[culture.git] / quizz_machine.py
index 5807b66..0d6d8f5 100755 (executable)
@@ -333,7 +333,7 @@ class QuizzMachine:
         )
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=True
+        self, c_quizzes, models_for_validation, both_directions=False
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -390,7 +390,7 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
+    def generate_quizzes(self, nb, model_for_generation):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -403,10 +403,7 @@ class QuizzMachine:
 
         seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
 
-        if reverse_cleanup:
-            temperature = 10.0
-        else:
-            temperature = 1.0
+        temperature = 10.0
 
         # First, we generate the answer at high temperature
 
@@ -433,7 +430,7 @@ class QuizzMachine:
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=temperature,
+            temperature=1.0,
             deterministic_synthesis=True,
             device=self.device,
         )