model_indexes = []
sum_logits, sum_nb_c_quizzes = 0, 0
- while (
- sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
- < nb_for_train + nb_for_test
- ):
+ def nb_generated():
+ return sum([sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()])
+
+ def nb_validated():
+ return sum(
+ [
+ sum([x.size(0) for x in recorded[n]])
+ for n in range(args.nb_correct_to_validate, len(models))
+ ]
+ )
+
+ while nb_validated() < nb_for_train + nb_for_test:
nb_to_validate = nb_for_train + nb_for_test
if len(model_indexes) == 0:
for n in range(nb_correct.max() + 1):
recorded[n].append(new_c_quizzes[nb_correct == n].clone())
- nb_validated = sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
- nb_generated = sum(
- [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
- )
-
log_string(
- f"keep c_quizzes {nb_validated*100/nb_generated:.02f}% kept total {nb_validated}/{nb_to_validate}"
+ f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_validate}"
)
# concatenate and shuffle
else:
del recorded[n]
- new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test]
+ new_c_quizzes = torch.cat(
+ [recorded[n] for n in range(args.nb_correct_to_validate, len(models)) if n in recorded], dim=0
+ )
+
+ new_c_quizzes = new_c_quizzes[
+ torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[
+ : nb_for_train + nb_for_test
+ ]
+ ]
quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
for n in recorded.keys():
- s = "_validated" if n == args.nb_correct_to_validate else ""
+ s = "_validated" if n >= args.nb_correct_to_validate and n < len(models) else ""
quizz_machine.problem.save_quizzes(
recorded[n][:72],
args.result_dir,