- nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+ nb_to_generate = nb_for_train + nb_for_test
+
+ if len(model_indexes) == 0:
+ model_indexes = [i.item() for i in torch.randperm(len(models))]
+
+ model = models[model_indexes.pop()]
+
+ new_c_quizzes, nb_correct, ave_seq_logproba = quizz_machine.create_c_quizzes(
+ nb=nb_to_generate,
+ model_for_generation=model,
+ models_for_validation=models,
+ min_ave_seq_logproba=min_ave_seq_logproba,