Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 14:43:06 +0000 (16:43 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 14:43:06 +0000 (16:43 +0200)
main.py

diff --git a/main.py b/main.py
index 7c8c836..85f2cb6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -95,7 +95,7 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--proba_hint", type=float, default=0.05)
+parser.add_argument("--proba_hint", type=float, default=0.25)
 
 parser.add_argument("--quizzes", type=str, default=None)
 
@@ -620,6 +620,17 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 ######################################################################
 
 
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+    return (
+        (prediction != quizzes)
+        .long()
+        .reshape(quizzes.size(0), 4, -1)
+        .sum(dim=2)
+        .max(dim=1)
+        .values
+    )
+
+
 def evaluate_quizzes(quizzes, models, local_device):
     nb_correct, nb_wrong = 0, 0
 
@@ -631,8 +642,8 @@ def evaluate_quizzes(quizzes, models, local_device):
             with_perturbations=True,
             local_device=local_device,
         )
-        nb_mistakes = (result != quizzes).long().sum(dim=1)
-        nb_correct += (nb_mistakes == 0).long()
+
+        nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long()
 
         result = predict_full(
             model=model,
@@ -640,8 +651,10 @@ def evaluate_quizzes(quizzes, models, local_device):
             with_perturbations=False,
             local_device=local_device,
         )
-        nb_mistakes = (result != quizzes).long().sum(dim=1)
-        nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
+
+        nb_wrong += (
+            max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong
+        ).long()
 
     to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
         nb_wrong >= args.nb_have_to_be_wrong