######################################################################
-def quiz_validation_1(models, c_quizzes, local_device):
- nb_have_to_be_correct = args.nb_models // 2
+def quiz_validation_paris(models, c_quizzes, local_device):
+ nb_have_to_be_correct = 3
nb_have_to_be_wrong = 1
- nb_runs = 1
+ nb_runs = 3
nb_mistakes_to_be_wrong = 5
record_wrong = []
return to_keep, wrong
-def quiz_validation_2(models, c_quizzes, local_device):
- nb_have_to_be_correct = 3
- nb_have_to_be_wrong = 1
- nb_runs = 3
-
- record_wrong = []
- nb_correct, nb_wrong = 0, 0
-
- for i, model in enumerate(models):
- assert i == model.id # a bit of paranoia
- model = copy.deepcopy(model).to(local_device).eval()
- log_probas_max, log_probas_min = None, None
- for _ in range(nb_runs):
- log_probas = model_ae_proba_solutions(
- model, c_quizzes, log_probas=True, reduce=False
- )
- log_probas_max = (
- log_probas
- if log_probas_max is None
- else log_probas.maximum(log_probas_max)
- )
- log_probas_min = (
- log_probas
- if log_probas_min is None
- else log_probas.minimum(log_probas_min)
- )
- probas = log_probas.sum(dim=1).exp()
- correct = (log_probas_min.exp() <= 0.75).long().sum(dim=1) == 0
- wrong = (log_probas_min.exp() <= 0.1).long().sum(dim=1) >= 3
- record_wrong.append(wrong[:, None])
- nb_correct += correct.long()
- nb_wrong += wrong.long()
-
- to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
- wrong = torch.cat(record_wrong, dim=1)
-
- return to_keep, wrong
-
-
-def quiz_validation(models, c_quizzes, local_device):
+def quiz_validation_berne(models, c_quizzes, local_device):
nb_have_to_be_correct = 3
nb_have_to_be_wrong = 1
nb_runs = 3
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
- to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
+ to_keep, record_wrong = quiz_validation_berne(
+ models, c_quizzes, local_device
+ )
q = c_quizzes[to_keep]
if q.size(0) > 0: