From 527c2d679484a2979206b1d9463f9bdf8f393c49 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 16:43:06 +0200 Subject: [PATCH] Update. --- main.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 7c8c836..85f2cb6 100755 --- a/main.py +++ b/main.py @@ -95,7 +95,7 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) parser.add_argument("--proba_prompt_noise", type=float, default=0.05) -parser.add_argument("--proba_hint", type=float, default=0.05) +parser.add_argument("--proba_hint", type=float, default=0.25) parser.add_argument("--quizzes", type=str, default=None) @@ -620,6 +620,17 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): ###################################################################### +def max_nb_mistakes_on_one_grid(quizzes, prediction): + return ( + (prediction != quizzes) + .long() + .reshape(quizzes.size(0), 4, -1) + .sum(dim=2) + .max(dim=1) + .values + ) + + def evaluate_quizzes(quizzes, models, local_device): nb_correct, nb_wrong = 0, 0 @@ -631,8 +642,8 @@ def evaluate_quizzes(quizzes, models, local_device): with_perturbations=True, local_device=local_device, ) - nb_mistakes = (result != quizzes).long().sum(dim=1) - nb_correct += (nb_mistakes == 0).long() + + nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long() result = predict_full( model=model, @@ -640,8 +651,10 @@ def evaluate_quizzes(quizzes, models, local_device): with_perturbations=False, local_device=local_device, ) - nb_mistakes = (result != quizzes).long().sum(dim=1) - nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong + + nb_wrong += ( + max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong + ).long() to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( nb_wrong >= args.nb_have_to_be_wrong -- 2.39.5