######################################################################
-def keep_good_quizzes(models, quizzes):
+def keep_good_quizzes(models, quizzes, required_nb_failures=1):
quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
if args.c_quiz_validation_mode == "proba":
log_string(f"nb_correct {count_nc}")
- to_keep = nc == (len(models) - 1)
+ to_keep = nc == (len(models) - required_nb_failures)
else:
raise ValueError(f"{args.c_quiz_validation_mode=}")
nb_to_generate_per_iteration = nb_to_create
nb_validated = 0
- recorded = []
+ recorded_validated = []
+ recorded_too_simple = []
start_time = time.perf_counter()
temperature_cold=args.temperature_cold,
)
+ recorded_too_simple.append(
+ keep_good_quizzes(models, c_quizzes, required_nb_failures=0)
+ )
+
c_quizzes = keep_good_quizzes(models, c_quizzes)
nb_validated[model_for_generation.id] += c_quizzes.size(0)
total_nb_validated = nb_validated.sum().item()
- recorded.append(c_quizzes)
+ recorded_validated.append(c_quizzes)
duration = time.perf_counter() - start_time
f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
)
- validated_quizzes = torch.cat(recorded, dim=0)
+ validated_quizzes = torch.cat(recorded_validated, dim=0)
+ too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
######################################################################
# store the new c_quizzes which have been validated
args.result_dir, prefix, vq, show_part_to_predict=False
)
+ vq = too_simple_quizzes[:128]
+
+ if vq.size(0) > 0:
+ prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
+ quiz_machine.save_quiz_illustrations(
+ args.result_dir, prefix, vq, show_part_to_predict=False
+ )
+
######################################################################