start_time = time.perf_counter()
- nb_validated = torch.zeros(len(models))
+ nb_validated = torch.zeros(len(models), dtype=torch.int64)
while nb_validated.sum() < nb_to_create:
# We balance the number of quizzes per model
- model_for_generation = models[nb_validated.argmin()]
+ model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
+ print(nb_validated, "using", model_for_generation.id)
c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
c_quizzes = keep_good_quizzes(models, c_quizzes)
- nb_validated[model.id] += c_quizzes.size(0)
+ nb_validated[model_for_generation.id] += c_quizzes.size(0)
total_nb_validated = nb_validated.sum().item()
recorded.append(c_quizzes)
e = "???"
log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)"
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)"
)
validated_quizzes = torch.cat(recorded, dim=0)