Update.
[culture.git] / quizz_machine.py
index 88f2c1c..d591d79 100755 (executable)
@@ -398,4 +398,16 @@ class QuizzMachine:
                 device=self.device,
             )
 
+            c_quizzes = self.reverse_time(c_quizzes)
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                device=self.device,
+            )
+
         return c_quizzes, seq_logproba.mean()