Update.
author François Fleuret <francois@fleuret.org>
Thu, 18 Jul 2024 21:56:57 +0000 (23:56 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 18 Jul 2024 21:56:57 +0000 (23:56 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index bdfe9fd..0d0d373 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -90,7 +90,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2.5)
+parser.add_argument("--generation_temperature", type=float, default=1.5)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
@@ -419,7 +419,12 @@ def keep_good_quizzes(models, quizzes):
         )
 
     elif args.c_quiz_validation_mode == "predict":
-        to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1)
+        nc = quiz_machine.solution_nb_correct(models, quizzes)
+        count_nc = tuple(
+            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
+        )
+        log_string(f"nb_correct {count_nc}")
+        to_keep = nc == (len(models) - 1)
 
     else:
         raise ValueError(f"{args.c_quiz_validation_mode=}")
diff --git a/quiz_machine.py b/quiz_machine.py
index bbd1b7b..032305a 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -611,7 +611,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=1,
+                temperature=0.75,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -624,7 +624,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=1,
+                temperature=0.75,
                 deterministic_synthesis=False,
                 device=self.device,
             )