From 098b30235186c21bee8d54861515376fa2e96b65 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 3 Sep 2024 15:06:37 +0200 Subject: [PATCH] Update. --- main.py | 52 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index fb0b4df..2376868 100755 --- a/main.py +++ b/main.py @@ -53,7 +53,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None) parser.add_argument("--inference_batch_size", type=int, default=25) -parser.add_argument("--nb_train_samples", type=int, default=40000) +parser.add_argument("--nb_train_samples", type=int, default=25000) parser.add_argument("--nb_test_samples", type=int, default=1000) @@ -61,7 +61,7 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) -parser.add_argument("--c_quiz_multiplier", type=int, default=4) +parser.add_argument("--c_quiz_multiplier", type=int, default=10) parser.add_argument("--learning_rate", type=float, default=5e-4) @@ -1079,7 +1079,9 @@ def model_ae_proba_solutions(model, input, log_proba=False): mask_generate = quiz_machine.make_quiz_mask( quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - targets, logits = targets_and_prediction(model, q, mask_generate) + targets, logits = targets_and_prediction( + model, q, mask_generate, prompt_noise=args.prompt_noise + ) loss_per_token = F.cross_entropy( logits.transpose(1, 2), targets, reduction="none" ) @@ -1400,7 +1402,7 @@ def save_badness_statistics( log_string(f"wrote {filename}") -def generate_ae_c_quizzes(models, local_device=main_device): +def generate_ae_c_quizzes(models, nb, local_device=main_device): criteria = [ # c_quiz_criterion_only_one, c_quiz_criterion_one_good_one_bad, @@ -1411,8 +1413,8 @@ def generate_ae_c_quizzes(models, local_device=main_device): # c_quiz_criterion_some, ] - for m in models: - m.eval().to(local_device) + # To be thread-safe we must make copies + models = [copy.deepcopy(model).to(local_device) for model in models] quad_order = ("A", "f_A", "B", "f_B") @@ -1426,7 +1428,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration_max = 4 * 3600 - wanted_nb = args.nb_train_samples // args.c_quiz_multiplier + wanted_nb = nb nb_to_save = 256 with torch.autograd.no_grad(): @@ -1515,6 +1517,10 @@ def generate_ae_c_quizzes(models, local_device=main_device): return torch.cat(a, dim=0).unique(dim=0) +def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): + record.append(generate_ae_c_quizzes(models, nb, local_device)) + + ###################################################################### current_epoch = 0 @@ -1600,9 +1606,37 @@ for n_epoch in range(current_epoch, args.nb_epochs): save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after") last_n_epoch_c_quizzes = n_epoch + nb_c_quizzes_to_generate = args.nb_train_samples // args.c_quiz_multiplier + + # -------------------------------------------------------------------- + + records, threads = [], [] + start_time = time.perf_counter() - c_quizzes = generate_ae_c_quizzes(models, local_device=main_device) + + for gpu in gpus: + t = threading.Thread( + target=thread_generate_ae_c_quizzes, + daemon=True, + args=(models, nb_c_quizzes_to_generate, records, gpu), + ) + + # To get a different sequence between threads + log_string(f"dummy {torch.rand(1)}") + threads.append(t) + t.start() + + for t in threads: + t.join() + time_c_quizzes = time.perf_counter() - start_time + + c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0) + + # -------------------------------------------------------------------- + + log_string(f"generated_c_quizzes {c_quizzes.size()=}") + time_train = 0 for model in models: model.test_accuracy = 0 @@ -1614,6 +1648,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): else: log_string(f"nb_c_quizzes {c_quizzes.size(0)}") + # -------------------------------------------------------------------- + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] -- 2.39.5