# Patch hunk: lines marked "-" are removed by the change; unmarked lines are context.
# The change stops passing nb_train_samples / nb_test_samples to the QuizMachine
# constructor. The closing parenthesis of this call is outside the visible hunk.
# NOTE(review): the assignment rebinds the name `quiz_machine` (presumably the
# imported module, since QuizMachine is looked up on it) to the new instance.
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
back_accuracy=back_accuracy,
batch_size=args.physical_batch_size,
result_dir=args.result_dir,
# Patch hunk(s): "+" lines are added and "-" lines removed by the change;
# unmarked lines are context. Indentation is not preserved in this view.
# Repeats generation rounds until nb_validated reaches nb_to_create, logging the
# running count and an estimated remaining time after each round.
nb_validated = 0
# Wall-clock reference point for the remaining-time estimate computed below.
+ start_time = time.perf_counter()
+
# Number of quizzes requested from generate_quizzes per round; currently set to
# the overall target nb_to_create.
+ nb_to_generate_per_iteration = nb_to_create
+
while nb_validated < nb_to_create:
# Pick the generating model with a random index drawn from [0, len(models)).
model_for_generation = models[torch.randint(len(models), (1,))]
c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
+ nb_to_generate_per_iteration,
model_for_generation=model_for_generation,
temperature=args.generation_temperature,
)
# NOTE(review): validated_quizzes is not assigned anywhere in the visible lines;
# presumably the validation step that produces it sits between the hunks — confirm.
if validated_quizzes is not None:
nb_validated = validated_quizzes.size(0)
# Seconds elapsed since start_time.
+ duration = time.perf_counter() - start_time
+
# Remaining-time estimate by linear extrapolation:
#   (quizzes still missing) * (elapsed seconds per validated quiz so far).
# `e` starts as a number of seconds and is then rebound to the display string.
# The nb_validated > 0 guard avoids dividing by zero before anything is validated.
+ if nb_validated > 0:
+ e = (nb_to_create - nb_validated) * duration / nb_validated
+ if e > 0:
# Truncate to whole seconds and let timedelta format the duration.
+ e = "~" + str(datetime.timedelta(seconds=int(e)))
+ else:
# Target reached or exceeded: nothing left to wait for.
+ e = "0s"
+ else:
# No validated quiz yet, so no rate to extrapolate from.
+ e = "???"
+
# The change appends the remaining-time estimate to the progress message.
log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (remaining time {e})"
)
# store the new c_quizzes which have been validated
# Overwrites the c_quiz counts stored on args with fixed values
# (100 for train, 10 for test).
# NOTE(review): the enclosing condition, if any, is outside the visible hunk —
# confirm these overrides are not applied unconditionally.
args.nb_new_c_quizzes_for_train = 100
args.nb_new_c_quizzes_for_test = 10
# Added by the patch ("+" lines). Returns a boolean tensor whose entries are each
# True with probability 0.5, shaped like `l` with its second dimension removed.
# Assumes token_logprobas has at least three dimensions (e.g. quizzes x models x
# tokens), since l[:, 0] needs two after the sum — TODO confirm against the caller.
# NOTE(review): the summed, sorted log-probabilities only supply the shape and
# device of the result; the accept/reject decision itself is random. This looks
# like a placeholder or baseline — confirm it is intentional.
+ def compute_valid_quizzes(token_logprobas):
# Sum over the last dimension, then sort the sums in ascending order along the new last dimension.
+ l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+ return torch.rand(l[:, 0].size(), device=l.device) < 0.5
+
######################################################################