From 2c32cbc02b3b1a69742b48d5ccd079690df48f3f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 15 Aug 2024 17:03:41 +0200 Subject: [PATCH] Update. --- main.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 78defa6..dcd76ad 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) parser.add_argument("--proba_understands", type=float, default=0.95) -parser.add_argument("--proba_not_understands", type=float, default=0.5) +parser.add_argument("--proba_not_understands", type=float, default=0.1) parser.add_argument("--temperature_hot", type=float, default=1.5) @@ -583,26 +583,26 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): # the most consistent from a model which is confident for s in range(proba_own_solution.size(0)): - dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands - nb_fails = dont_get_this_quiz.long().sum() - if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: - for model in models: - if dont_get_this_quiz[model.id]: - assert proba_own_solution[s, model.id] < args.proba_understands - proba_other_solutions = model_proba_solutions( - model, solved_c_quizzes[s] - ) - - # proba_other_solutions += torch.rand(proba_other_solutions.size()) * 1e-6 - - proba_other_solutions[dont_get_this_quiz] = -1 - # print( - # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n" - # ) - i = proba_other_solutions.argmax() - model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) - teaching_count[i, model.id] += 1 - nb_validated += 1 + if proba_own_solution[s, :].min() < args.proba_not_understands: + dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands + nb_fails = dont_get_this_quiz.long().sum() + if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: + for model in models: + if dont_get_this_quiz[model.id]: + assert ( + proba_own_solution[s, model.id] < args.proba_understands + ) + proba_other_solutions = model_proba_solutions( + model, solved_c_quizzes[s] + ) + proba_other_solutions += ( + torch.rand(proba_other_solutions.size()) * 1e-6 + ) + proba_other_solutions[dont_get_this_quiz] = -1 + i = proba_other_solutions.argmax() + model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) + teaching_count[i, model.id] += 1 + nb_validated += 1 duration = time.perf_counter() - start_time -- 2.39.5