parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--proba_hint", type=float, default=0.05)
+parser.add_argument("--proba_hint", type=float, default=0.25)
parser.add_argument("--quizzes", type=str, default=None)
######################################################################
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+ return (
+ (prediction != quizzes)
+ .long()
+ .reshape(quizzes.size(0), 4, -1)
+ .sum(dim=2)
+ .max(dim=1)
+ .values
+ )
+
+
def evaluate_quizzes(quizzes, models, local_device):
nb_correct, nb_wrong = 0, 0
with_perturbations=True,
local_device=local_device,
)
- nb_mistakes = (result != quizzes).long().sum(dim=1)
- nb_correct += (nb_mistakes == 0).long()
+
+ nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long()
result = predict_full(
model=model,
with_perturbations=False,
local_device=local_device,
)
- nb_mistakes = (result != quizzes).long().sum(dim=1)
- nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
+
+ nb_wrong += (
+ max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong
+ ).long()
to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
nb_wrong >= args.nb_have_to_be_wrong