From: François Fleuret Date: Tue, 24 Sep 2024 15:23:34 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=bb161b255d2d066d978f83227ca2ad79ac54ddec;p=culture.git Update. --- diff --git a/main.py b/main.py index a70c758..de04d5b 100755 --- a/main.py +++ b/main.py @@ -952,9 +952,24 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -# for model in models: -# inject_plasticity(model, args.proba_plasticity) -# model.test_accuracy = 0 +#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +### # The c quizzes used to estimate the test accuracy have to be +### # solvable without hints +### +### nb_correct, _ = evaluate_quizzes( +### quizzes=train_c_quizzes, +### models=models, +### with_hints=False, +### local_device=main_device, +### ) +### nb_correct = nb_correct.to("cpu") +### +### test_c_quizzes = train_c_quizzes[nb_correct >= len(models)//2] +### +### for model in models: +### inject_plasticity(model, args.proba_plasticity) +### model.test_accuracy = 0 +#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter() @@ -1014,7 +1029,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) nb_correct = nb_correct.to("cpu") - test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct] + test_c_quizzes = train_c_quizzes[nb_correct >= len(models) // 2] for model in models: inject_plasticity(model, args.proba_plasticity)