From 0ea3d8f7895969007d76563e59b639811a6347ef Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 10:40:13 +0200 Subject: [PATCH] Update. --- main.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 77dcd2f..e38cbc0 100755 --- a/main.py +++ b/main.py @@ -677,11 +677,11 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): model = copy.deepcopy(model).to(local_device).eval() result = predict_full( model=model, - input=c_quizzes, + input=quizzes, fraction_with_hints=fraction_with_hints, local_device=local_device, ) - nb_mistakes = (result != c_quizzes).long().sum(dim=1) + nb_mistakes = (result != quizzes).long().sum(dim=1) nb_correct += (nb_mistakes == 0).long() nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong @@ -768,7 +768,10 @@ def save_quiz_image( c_quizzes = c_quizzes.to(local_device) to_keep, nb_correct, nb_wrong = evaluate_quizzes( - quizzes=c_quizzes, models=models, local_device=local_device + quizzes=c_quizzes, + models=models, + fraction_with_hints=0, + local_device=local_device, ) if solvable_only: -- 2.39.5