From bb161b255d2d066d978f83227ca2ad79ac54ddec Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 24 Sep 2024 17:23:34 +0200 Subject: [PATCH] Update. --- main.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index a70c758..de04d5b 100755 --- a/main.py +++ b/main.py @@ -952,9 +952,24 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -# for model in models: -# inject_plasticity(model, args.proba_plasticity) -# model.test_accuracy = 0 +#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +### # The c quizzes used to estimate the test accuracy have to be +### # solvable without hints +### +### nb_correct, _ = evaluate_quizzes( +### quizzes=train_c_quizzes, +### models=models, +### with_hints=False, +### local_device=main_device, +### ) +### nb_correct = nb_correct.to("cpu") +### +### test_c_quizzes = train_c_quizzes[nb_correct >= len(models)//2] +### +### for model in models: +### inject_plasticity(model, args.proba_plasticity) +### model.test_accuracy = 0 +#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter() @@ -1014,7 +1029,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) nb_correct = nb_correct.to("cpu") - test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct] + test_c_quizzes = train_c_quizzes[nb_correct >= len(models) // 2] for model in models: inject_plasticity(model, args.proba_plasticity) -- 2.39.5