+ log_string(
+ f"keep c_quizzes {nb_validated*100/nb_generated:.02f}% kept total {nb_validated}/{nb_to_validate}"
+ )
+
+ # concatenate and shuffle
+ for n in recorded.keys():
+ if len(recorded[n]) > 0:
+ q = torch.cat(recorded[n], dim=0)
+ q = q[torch.randperm(q.size(0), device=q.device)]
+ recorded[n] = q
+ else:
+ del recorded[n]
+
+ new_c_quizzes = recorded[args.nb_correct_to_validate][: nb_for_train + nb_for_test]
+
+ quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+ quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+
+ for n in recorded.keys():
+ s = "_validated" if n == args.nb_correct_to_validate else ""
+ quizz_machine.problem.save_quizzes(
+ recorded[n][:72],
+ args.result_dir,
+ f"culture_c_quiz_{n_epoch:04d}_N{n}{s}",
+ )
+
+ return sum_logits / sum_nb_c_quizzes