From b078ef92ab3661672e08519a3b8204b9fde672d1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 15:40:50 +0200 Subject: [PATCH] Update. --- main.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index ef340ea..7bdd09e 100755 --- a/main.py +++ b/main.py @@ -662,7 +662,7 @@ for i in range(args.nb_models): ###################################################################### -def evaluate_quizzes(quizzes, models, with_perturbations, local_device): +def evaluate_quizzes(quizzes, models, local_device): nb_correct, nb_wrong = 0, 0 for model in models: @@ -670,11 +670,17 @@ def evaluate_quizzes(quizzes, models, with_perturbations, local_device): result = predict_full( model=model, input=quizzes, - with_perturbations=with_perturbations, + with_perturbations=True, local_device=local_device, ) - nb_mistakes = (result != quizzes).long().sum(dim=1) nb_correct += (nb_mistakes == 0).long() + result = predict_full( + model=model, + input=quizzes, + with_perturbations=False, + local_device=local_device, + ) + nb_mistakes = (result != quizzes).long().sum(dim=1) nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( @@ -687,26 +693,6 @@ def evaluate_quizzes(quizzes, models, with_perturbations, local_device): ###################################################################### -def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device): - nb_removed = 0 - for input in c_quizzes.split(args.eval_batch_size): - _, nb_correct, nb_wrong = evaluate_quizzes( - quizzes=input, - models=models, - with_perturbations=False, - local_device=local_device, - ) - - to_remove = nb_wrong > 0 - nb_removed += to_remove.long().sum() - - if nb_removed >= nb_to_remove: - break - - -###################################################################### - - def identity_quizzes(quizzes): quizzes = quizzes.reshape(quizzes.size(0), 4, -1) return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & ( @@ -741,7 +727,6 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): to_keep, nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, - with_perturbations=True, local_device=local_device, ) @@ -787,7 +772,6 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device): to_keep, nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, - with_perturbations=False, local_device=local_device, ) -- 2.39.5