+def create_c_quizzes(
+ models,
+ quizz_machine,
+ nb_for_train=1000,
+ nb_for_test=100,
+ min_ave_seq_logproba=None,
+):
+ # We will store the generated quizzes for each number of
+ # correct prediction
+ recorded = dict([(n, []) for n in range(len(models) + 1)])
+
+ model_indexes = []
+ sum_logits, sum_nb_c_quizzes = 0, 0
+
+ while (
+ sum([x.size(0) for x in recorded[args.nb_correct_to_validate]])
+ < nb_for_train + nb_for_test
+ ):
+ nb_to_validate = nb_for_train + nb_for_test
+
+ if len(model_indexes) == 0:
+ model_indexes = [i.item() for i in torch.randperm(len(models))]
+
+ model = models[model_indexes.pop()]
+
+ new_c_quizzes, nb_correct, ave_seq_logproba = quizz_machine.create_c_quizzes(
+ nb=nb_to_validate,
+ model_for_generation=model,
+ models_for_validation=models,
+ min_ave_seq_logproba=min_ave_seq_logproba,
+ n_epoch=n_epoch,
+ result_dir=args.result_dir,
+ logger=log_string,
+ )