- # ------------------------------------------------------------
-
- file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
-
- with open(file_name, "w") as logp_file:
- while (
- valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0)
- < nb_to_create
- ):
- # Select a model at random to generate the new quizzes
-
- model_for_generation = models[torch.randint(len(models), (1,))]
-
- c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
- model_for_generation=model_for_generation,
- temperature=args.generation_temperature,
- )
-
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
- if c_quizzes.size(0) > 0:
- logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
- for l in logproba:
- s = " ".join([str(x.item()) for x in l])
- logp_file.write(s + "\n")
- quizzes_and_logproba_records.append((c_quizzes, logproba))
-
- nb_validated = valid_c_quizzes(
- quizzes_and_logproba_records, standard_validity
- ).size(0)
-
- log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
- )
-
- # store the new c_quizzes which have been validated
-
- new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity)
-
- quiz_machine.reverse_random_half_in_place(new_c_quizzes)
-
- quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
- quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+def compute_valid_quizzes(token_logprobas):
+ l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+ return (l[:, 0] < math.log(args.proba_not_understands)) & (
+ l[:, 1] > math.log(args.proba_understands)
+ )