Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 20:26:20 +0000 (22:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 20:26:20 +0000 (22:26 +0200)
main.py

diff --git a/main.py b/main.py
index cd9ec20..00722d6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -500,8 +500,10 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
 
     q_p, q_g = quizzes.to(local_device).chunk(2)
 
-    # Half of the samples train the prediction, and we inject noise in
-    # all, and hints in half
+    # Half of the samples train the prediction. We inject noise in all
+    # to avoid drift of the culture toward "finding waldo" type of
+    # complexity, and hints in half to allow dealing with hints when
+    # validating c quizzes
     b_p = samples_for_prediction_imt(q_p)
     b_p = add_noise_imt(b_p)
     half = torch.rand(b_p.size(0)) < 0.5
@@ -673,7 +675,7 @@ def identity_quizzes(quizzes):
 
 def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
     record = []
-    nb_validated = 0
+    nb_generated, nb_validated = 0, 0
 
     start_time = time.perf_counter()
     last_log = -1
@@ -689,12 +691,15 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             model=model, nb=args.eval_batch_size * 10, local_device=local_device
         )
 
+        nb_generated += c_quizzes.size(0)
+
         c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
 
         if c_quizzes.size(0) > 0:
-            # Select the ones that are solved properly by some models and
-            # not understood by others
-
+            # Select the ones that are solved properly by some models
+            # and not understood by others. We add "hints" to allow
+            # the current models to deal with functionally more
+            # complex quizzes than the ones they have been trained on
             nb_correct, nb_wrong = evaluate_quizzes(
                 quizzes=c_quizzes,
                 models=models,
@@ -735,6 +740,9 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
     duration = time.perf_counter() - start_time
 
     log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
+    log_string(
+        f"validation_rate {nb_validated} / {nb_generated} ({nb_validated*100/nb_generated:.02e}%)"
+    )
 
     return torch.cat(record).to("cpu")
 
@@ -962,6 +970,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
 
+        # The c quizzes used to estimate the test accuracy have to be
+        # solvable without hints
         nb_correct, _ = evaluate_quizzes(
             quizzes=train_c_quizzes,
             models=models,