Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:30:59 +0000 (10:30 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:30:59 +0000 (10:30 +0200)
main.py

diff --git a/main.py b/main.py
index 195afa8..77dcd2f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -592,7 +592,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
     )
-    result = predict_full(modelquizzes, local_device=local_device)
+    result = predict_full(model=model, input=quizzes, local_device=local_device)
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
@@ -670,14 +670,14 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def evaluate_quizzes(c_quizzes, models, fraction_with_hints, local_device):
+def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
     nb_correct, nb_wrong = 0, 0
 
     for model in models:
         model = copy.deepcopy(model).to(local_device).eval()
         result = predict_full(
             model=model,
-            quizzes=c_quizzes,
+            input=c_quizzes,
             fraction_with_hints=fraction_with_hints,
             local_device=local_device,
         )
@@ -767,7 +767,9 @@ def save_quiz_image(
 ):
     c_quizzes = c_quizzes.to(local_device)
 
-    to_keep, nb_correct, nb_wrong = evaluate_quizzes(c_quizzes, models, local_device)
+    to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+        quizzes=c_quizzes, models=models, local_device=local_device
+    )
 
     if solvable_only:
         c_quizzes = c_quizzes[to_keep]