From 3457c65462cc8ad7cacf135c994ac5bfd89a9f39 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 21 Jul 2024 00:11:15 +0200 Subject: [PATCH] Update. --- main.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index c9c30c3..7588a50 100755 --- a/main.py +++ b/main.py @@ -412,7 +412,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): ###################################################################### -def keep_good_quizzes(models, quizzes): +def keep_good_quizzes(models, quizzes, required_nb_failures=1): quizzes = quizzes[quiz_machine.non_trivial(quizzes)] if args.c_quiz_validation_mode == "proba": @@ -432,7 +432,7 @@ def keep_good_quizzes(models, quizzes): log_string(f"nb_correct {count_nc}") - to_keep = nc == (len(models) - 1) + to_keep = nc == (len(models) - required_nb_failures) else: raise ValueError(f"{args.c_quiz_validation_mode=}") @@ -452,7 +452,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_to_generate_per_iteration = nb_to_create nb_validated = 0 - recorded = [] + recorded_validated = [] + recorded_too_simple = [] start_time = time.perf_counter() @@ -471,12 +472,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 temperature_cold=args.temperature_cold, ) + recorded_too_simple.append( + keep_good_quizzes(models, c_quizzes, required_nb_failures=0) + ) + c_quizzes = keep_good_quizzes(models, c_quizzes) nb_validated[model_for_generation.id] += c_quizzes.size(0) total_nb_validated = nb_validated.sum().item() - recorded.append(c_quizzes) + recorded_validated.append(c_quizzes) duration = time.perf_counter() - start_time @@ -495,7 +500,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" ) - validated_quizzes = torch.cat(recorded, dim=0) + validated_quizzes = torch.cat(recorded_validated, dim=0) + too_simple_quizzes = torch.cat(recorded_too_simple, dim=0) ###################################################################### # store the new c_quizzes which have been validated @@ -519,6 +525,14 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 args.result_dir, prefix, vq, show_part_to_predict=False ) + vq = too_simple_quizzes[:128] + + if vq.size(0) > 0: + prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple" + quiz_machine.save_quiz_illustrations( + args.result_dir, prefix, vq, show_part_to_predict=False + ) + ###################################################################### -- 2.39.5