Update.
[culture.git] / quizz_machine.py
index 806dde7..6f7492d 100755 (executable)
@@ -386,8 +386,11 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-        temperature = 10
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
 
         # warnings.warn("noise injection", RuntimeWarning)
         # noise_std = torch.rand(1).item()