From 6d209c8dbbef8e9978f383a3e8095a35bb0deeb3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 08:27:55 +0200 Subject: [PATCH] Update. --- main.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 1cf31b3..c54d701 100755 --- a/main.py +++ b/main.py @@ -91,11 +91,11 @@ parser.add_argument("--max_fail_to_validate", type=int, default=2) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98) -parser.add_argument("--proba_understands", type=float, default=0.9) +parser.add_argument("--proba_understands", type=float, default=0.95) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--temperature_hot", type=float, default=1.25) +parser.add_argument("--temperature_hot", type=float, default=1.5) parser.add_argument("--temperature_cold", type=float, default=1) @@ -454,13 +454,13 @@ def one_epoch(model, quiz_machine, local_device=main_device): def model_transformer_hot(model): - # model.temperature = args.temperature_hot - model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2)) + model.temperature = args.temperature_hot + # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2)) def model_transformer_cold(model): - pass - # model.temperature = args.temperature_cold + model.temperature = args.temperature_cold + # pass c_quizzes_procedure = [ @@ -469,11 +469,6 @@ c_quizzes_procedure = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), ] -c_quizzes_procedure_ = [ - (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), model_transformer_hot), - (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold), -] - def save_additional_results(models, science_w_quizzes): for model in models: @@ -611,6 +606,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) recorded_validated.append(c_quizzes) + nb_validated = c_quizzes.size(0) + else: + nb_validated = 0 total_nb_validated = nb_validated_per_model.sum().item() @@ -631,10 +629,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 else: e = "???" - nb_validated = ( - recorded_validated[-1].size(0) if len(recorded_validated) > 0 else 0 - ) - log_string( f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" ) -- 2.20.1