Update.
author François Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 09:06:31 +0000 (12:06 +0300)
committer François Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 09:06:31 +0000 (12:06 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index fd8ab41..67c57c0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -437,7 +437,8 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+        nv = " ".join([str(x.item()) for x in nv])
 
         log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
 
diff --git a/quizz_machine.py b/quizz_machine.py
index 806dde7..6f7492d 100755 (executable)
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -386,8 +386,11 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-        temperature = 10
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
 
         # warnings.warn("noise injection", RuntimeWarning)
         # noise_std = torch.rand(1).item()