- new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
- n_epoch=n_epoch,
- result_dir=args.result_dir,
- logger=log_string,
- nb=nb_to_generate,
- model=model,
- other_models=other_models,
- desired_average_logits=desired_average_logits,
+ # ------------------------------------------------------------
+
+ standard_validity = lambda nb_correct: torch.logical_and(
+ nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+ )
+
+ while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create:
+ model_for_generation = models[torch.randint(len(models), (1,))]
+
+ c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes(
+ nb_to_create,
+ model_for_generation=model_for_generation,
+ reverse_cleanup=args.reverse_cleanup,