Merge branch 'dev'
[culture.git] / quiz_machine.py
index 26a0d8b..f0fb408 100755 (executable)
@@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
+    temperature,
+    deterministic_synthesis,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -123,9 +123,11 @@ class QuizMachine:
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
         quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-        return not self.problem.trivial_prompts_and_answers(
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
         )
 
     def reverse_time(self, quizzes):