From e5efa329be244007e11013af84be1f448a04e1a0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 26 Jun 2024 15:46:03 +0200 Subject: [PATCH] Update. --- main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index d0de5af..b7b55b5 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--nb_correct_to_validate", type=int, default=4) + parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -361,10 +363,9 @@ def create_c_quizzes( model_indexes = [] sum_logits, sum_nb_c_quizzes = 0, 0 - nb_correct_to_validate = len(models) - 1 while ( - sum([x.size(0) for x in recorded[nb_correct_to_validate]]) + sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) < nb_for_train + nb_for_test ): nb_to_validate = nb_for_train + nb_for_test @@ -395,7 +396,7 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nb_validated = sum([x.size(0) for x in recorded[nb_correct_to_validate]]) + nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) nb_generated = sum( [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()] ) @@ -413,13 +414,13 @@ def create_c_quizzes( else: del recorded[n] - new_c_quizzes = recorded[nb_correct_to_validate][: nb_for_train + nb_for_test] + new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test] quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) for n in recorded.keys(): - s = "_validated" if n == nb_correct_to_validate else "" + s = "_validated" if n == args.nb_correct_to_validate else "" quizz_machine.problem.save_quizzes( recorded[n][:72], args.result_dir, -- 2.39.5