parser.add_argument("--nb_gpts", type=int, default=5)
+parser.add_argument("--nb_correct_to_validate", type=int, default=4)
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
######################################################################
model_indexes = []
sum_logits, sum_nb_c_quizzes = 0, 0
- nb_correct_to_validate = len(models) - 1
while (
- sum([x.size(0) for x in recorded[nb_correct_to_validate]])
+ sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
< nb_for_train + nb_for_test
):
nb_to_validate = nb_for_train + nb_for_test
for n in range(nb_correct.max() + 1):
recorded[n].append(new_c_quizzes[nb_correct == n].clone())
- nb_validated = sum([x.size(0) for x in recorded[nb_correct_to_validate]])
+ nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
nb_generated = sum(
[sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
)
else:
del recorded[n]
- new_c_quizzes = recorded[nb_correct_to_validate][: nb_for_train + nb_for_test]
+ new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test]
quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
for n in recorded.keys():
- s = "_validated" if n == nb_correct_to_validate else ""
+ s = "_validated" if n == args.nb_correct_to_validate else ""
quizz_machine.problem.save_quizzes(
recorded[n][:72],
args.result_dir,