Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 30 Jun 2024 10:10:56 +0000 (13:10 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 30 Jun 2024 10:10:56 +0000 (13:10 +0300)
quizz_machine.py

index c5870d0..18d0e0b 100755 (executable)
@@ -379,13 +379,14 @@ class QuizzMachine:
         )
 
         ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
-        ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1
+        ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1
         ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
+        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
         # bracketing of the temperature to get the target logproba
 
-        temperature = 1
+        warnings.warn("high temperature!", RuntimeWarning)
+        temperature = 2
         d_temperature = 1 / 3
 
         while True: