Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:40:13 +0000 (10:40 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:40:13 +0000 (10:40 +0200)
main.py

diff --git a/main.py b/main.py
index 77dcd2f..e38cbc0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -677,11 +677,11 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
         model = copy.deepcopy(model).to(local_device).eval()
         result = predict_full(
             model=model,
-            input=c_quizzes,
+            input=quizzes,
             fraction_with_hints=fraction_with_hints,
             local_device=local_device,
         )
-        nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+        nb_mistakes = (result != quizzes).long().sum(dim=1)
         nb_correct += (nb_mistakes == 0).long()
         nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
 
@@ -768,7 +768,10 @@ def save_quiz_image(
     c_quizzes = c_quizzes.to(local_device)
 
     to_keep, nb_correct, nb_wrong = evaluate_quizzes(
-        quizzes=c_quizzes, models=models, local_device=local_device
+        quizzes=c_quizzes,
+        models=models,
+        fraction_with_hints=0,
+        local_device=local_device,
     )
 
     if solvable_only: