From 10c44b2a38c49a0353de04da171148480a868ade Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 26 Jun 2024 17:36:33 +0200 Subject: [PATCH] Update. --- main.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index b7b55b5..d063423 100755 --- a/main.py +++ b/main.py @@ -364,10 +364,18 @@ def create_c_quizzes( model_indexes = [] sum_logits, sum_nb_c_quizzes = 0, 0 - while ( - sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) - < nb_for_train + nb_for_test - ): + def nb_generated(): + return sum([sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]) + + def nb_validated(): + return sum( + [ + sum([x.size(0) for x in recorded[n]]) + for n in range(args.nb_correct_to_validate, len(models)) + ] + ) + + while nb_validated() < nb_for_train + nb_for_test: nb_to_validate = nb_for_train + nb_for_test if len(model_indexes) == 0: @@ -396,13 +404,8 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) - nb_generated = sum( - [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()] - ) - log_string( - f"keep c_quizzes {nb_validated*100/nb_generated:.02f}% kept total {nb_validated}/{nb_to_validate}" + f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_validate}" ) # concatenate and shuffle @@ -414,13 +417,21 @@ def create_c_quizzes( else: del recorded[n] - new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test] + new_c_quizzes = torch.cat( + [recorded[n] for n in range(args.nb_correct_to_validate, len(models))], dim=0 + ) + + new_c_quizzes = new_c_quizzes[ + torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[ + : nb_for_train + nb_for_test + ] + ] quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) for n in recorded.keys(): - s = "_validated" if n == args.nb_correct_to_validate else "" + s = "_validated" if n >= args.nb_correct_to_validate and n < len(models) else "" quizz_machine.problem.save_quizzes( recorded[n][:72], args.result_dir, -- 2.39.5