######################################################################
-def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
+def evaluate_quizzes(quizzes, models, local_device):
nb_correct, nb_wrong = 0, 0
for model in models:
result = predict_full(
model=model,
input=quizzes,
- with_perturbations=with_perturbations,
+ with_perturbations=True,
local_device=local_device,
)
- nb_mistakes = (result != quizzes).long().sum(dim=1)
nb_correct += (nb_mistakes == 0).long()
+ result = predict_full(
+ model=model,
+ input=quizzes,
+ with_perturbations=False,
+ local_device=local_device,
+ )
+ nb_mistakes = (result != quizzes).long().sum(dim=1)
nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
######################################################################
-def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device):
- nb_removed = 0
- for input in c_quizzes.split(args.eval_batch_size):
- _, nb_correct, nb_wrong = evaluate_quizzes(
- quizzes=input,
- models=models,
- with_perturbations=False,
- local_device=local_device,
- )
-
- to_remove = nb_wrong > 0
- nb_removed += to_remove.long().sum()
-
- if nb_removed >= nb_to_remove:
- break
-
-
-######################################################################
-
-
def identity_quizzes(quizzes):
quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
- with_perturbations=True,
local_device=local_device,
)
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
- with_perturbations=False,
local_device=local_device,
)