Update.
author: François Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 22:11:15 +0000 (00:11 +0200)
committer: François Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 22:11:15 +0000 (00:11 +0200)
main.py

diff --git a/main.py b/main.py
index c9c30c3..7588a50 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -412,7 +412,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
-def keep_good_quizzes(models, quizzes):
+def keep_good_quizzes(models, quizzes, required_nb_failures=1):
     quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
 
     if args.c_quiz_validation_mode == "proba":
@@ -432,7 +432,7 @@ def keep_good_quizzes(models, quizzes):
 
         log_string(f"nb_correct {count_nc}")
 
-        to_keep = nc == (len(models) - 1)
+        to_keep = nc == (len(models) - required_nb_failures)
 
     else:
         raise ValueError(f"{args.c_quiz_validation_mode=}")
@@ -452,7 +452,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     nb_to_generate_per_iteration = nb_to_create
     nb_validated = 0
 
-    recorded = []
+    recorded_validated = []
+    recorded_too_simple = []
 
     start_time = time.perf_counter()
 
@@ -471,12 +472,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             temperature_cold=args.temperature_cold,
         )
 
+        recorded_too_simple.append(
+            keep_good_quizzes(models, c_quizzes, required_nb_failures=0)
+        )
+
         c_quizzes = keep_good_quizzes(models, c_quizzes)
 
         nb_validated[model_for_generation.id] += c_quizzes.size(0)
         total_nb_validated = nb_validated.sum().item()
 
-        recorded.append(c_quizzes)
+        recorded_validated.append(c_quizzes)
 
         duration = time.perf_counter() - start_time
 
@@ -495,7 +500,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
-    validated_quizzes = torch.cat(recorded, dim=0)
+    validated_quizzes = torch.cat(recorded_validated, dim=0)
+    too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
 
     ######################################################################
     # store the new c_quizzes which have been validated
@@ -519,6 +525,14 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
 
+    vq = too_simple_quizzes[:128]
+
+    if vq.size(0) > 0:
+        prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
+        quiz_machine.save_quiz_illustrations(
+            args.result_dir, prefix, vq, show_part_to_predict=False
+        )
+
 
 ######################################################################