X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=b7b55b5670aa3fe3da6923711c600488bd5f1e3f;hb=e5efa329be244007e11013af84be1f448a04e1a0;hp=d0de5afd5b687a463ed9945604297ec712f97240;hpb=2186d96fccfc525884f1b3fb722c40642891ab0a;p=culture.git diff --git a/main.py b/main.py index d0de5af..b7b55b5 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--nb_correct_to_validate", type=int, default=4) + parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -361,10 +363,9 @@ def create_c_quizzes( model_indexes = [] sum_logits, sum_nb_c_quizzes = 0, 0 - nb_correct_to_validate = len(models) - 1 while ( - sum([x.size(0) for x in recorded[nb_correct_to_validate]]) + sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) < nb_for_train + nb_for_test ): nb_to_validate = nb_for_train + nb_for_test @@ -395,7 +396,7 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nb_validated = sum([x.size(0) for x in recorded[nb_correct_to_validate]]) + nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]]) nb_generated = sum( [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()] ) @@ -413,13 +414,13 @@ def create_c_quizzes( else: del recorded[n] - new_c_quizzes = recorded[nb_correct_to_validate][: nb_for_train + nb_for_test] + new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test] quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) for n in recorded.keys(): - s = "_validated" if n == nb_correct_to_validate else "" + s = "_validated" if n == args.nb_correct_to_validate else "" quizz_machine.problem.save_quizzes( recorded[n][:72], args.result_dir,