From 49649bb113ce0857bc168dca58b59d583d1d4ae1 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 21 Sep 2024 22:26:20 +0200
Subject: [PATCH] Update.

---
 main.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/main.py b/main.py
index cd9ec20..00722d6 100755
--- a/main.py
+++ b/main.py
@@ -500,8 +500,10 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
 
         q_p, q_g = quizzes.to(local_device).chunk(2)
 
-        # Half of the samples train the prediction, and we inject noise in
-        # all, and hints in half
+        # Half of the samples train the prediction. We inject noise in all
+        # to avoid drift of the culture toward "finding waldo" type of
+        # complexity, and hints in half to allow dealing with hints when
+        # validating c quizzes
         b_p = samples_for_prediction_imt(q_p)
         b_p = add_noise_imt(b_p)
         half = torch.rand(b_p.size(0)) < 0.5
@@ -673,7 +675,7 @@ def identity_quizzes(quizzes):
 
 def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
     record = []
-    nb_validated = 0
+    nb_generated, nb_validated = 0, 0
 
     start_time = time.perf_counter()
     last_log = -1
@@ -689,12 +691,15 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             model=model, nb=args.eval_batch_size * 10, local_device=local_device
         )
 
+        nb_generated += c_quizzes.size(0)
+
         c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
 
         if c_quizzes.size(0) > 0:
-            # Select the ones that are solved properly by some models and
-            # not understood by others
-
+            # Select the ones that are solved properly by some models
+            # and not understood by others. We add "hints" to allow
+            # the current models to deal with functionally more
+            # complex quizzes than the ones they have been trained on
             nb_correct, nb_wrong = evaluate_quizzes(
                 quizzes=c_quizzes,
                 models=models,
@@ -735,6 +740,9 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
 
     duration = time.perf_counter() - start_time
     log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
+    log_string(
+        f"validation_rate {nb_validated} / {nb_generated} ({nb_validated*100/nb_generated:.02e}%)"
+    )
 
     return torch.cat(record).to("cpu")
 
@@ -962,6 +970,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
 
+        # The c quizzes used to estimate the test accuracy have to be
+        # solvable without hints
        nb_correct, _ = evaluate_quizzes(
            quizzes=train_c_quizzes,
            models=models,
-- 
2.39.5