Simplify c_quiz validation: fold the selection criterion into a new keep_good_quizzes(), rename create_c_quizzes() to record_new_c_quizzes(), and stop saving per-quiz log-probabilities.
author    François Fleuret <francois@fleuret.org>
          Mon, 15 Jul 2024 18:26:15 +0000 (20:26 +0200)
committer François Fleuret <francois@fleuret.org>
          Mon, 15 Jul 2024 18:26:15 +0000 (20:26 +0200)
main.py
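
The selection rule that replaces compute_valid_quizzes() below can be read off the new keep_good_quizzes(). The following is a standalone paraphrase of that rule, assuming a precomputed N x M x T tensor of per-token solution log-probabilities (N quizzes, M models, T solution tokens, as in the comment being removed); the function name, the default thresholds, and the toy tensor at the end are illustrative, not taken from the repository.

import math

import torch


def select_c_quizzes(token_logprobas, proba_not_understands=0.1, proba_understands=0.5):
    # Sum the per-token log-probabilities into one score per (quiz, model),
    # then sort the models per quiz, so column 0 is the least confident model.
    l = token_logprobas.sum(dim=-1).sort(dim=-1).values  # N x M

    # Keep a quiz when at least one model fails it (score below the
    # "does not understand" threshold) while every other model solves it
    # (the second-lowest score already clears the "understands" threshold).
    return (l[:, 0] < math.log(proba_not_understands)) & (
        l[:, 1] > math.log(proba_understands)
    )


# Toy check: 3 quizzes, 2 models, 4 solution tokens.
token_logprobas = torch.log(torch.rand(3, 2, 4).clamp(min=1e-3))
print(select_c_quizzes(token_logprobas))

In the actual function below, the same mask indexes the quizzes themselves, with an optional args.dirty_debug path that keeps a random half instead.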

diff --git a/main.py b/main.py
index 02259b2..ff36e98 100755
--- a/main.py
+++ b/main.py
@@ -368,53 +368,35 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
-# This is the key routine that decides what generated quizzes to keep
 
+def keep_good_quizzes(models, quizzes):
+    quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
+    token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
 
-# token_logprobas are NxMxT where M is the number of models
-# def compute_valid_quizzes_(token_logprobas):
-# warnings.warn("validation with uniform constraints", RuntimeWarning)
-# l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
-# return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
-
-# token_logprobas are NxMxT where M is the number of models
-
-
-def compute_valid_quizzes(token_logprobas):
     l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-    return (l[:, 0] < math.log(args.proba_not_understands)) & (
+
+    to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
         l[:, 1] > math.log(args.proba_understands)
     )
 
+    if args.dirty_debug:
+        # warnings.warn("DEBUG", RuntimeWarning)
+        to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5
 
-def extract_valid_quizzes_and_logprobas(recorded):
-    validated_quizzes, validated_logprobas = [], []
-    for quizzes, token_logprobas in recorded:
-        validated_indices = compute_valid_quizzes(token_logprobas)
-        validated_quizzes.append(quizzes[validated_indices])
-        validated_logprobas.append(token_logprobas[validated_indices])
-
-    if len(validated_quizzes) > 0:
-        return torch.cat(validated_quizzes, dim=0), torch.cat(
-            validated_logprobas, dim=0
-        )
-    else:
-        return None, None
+    return quizzes[to_keep]
 
 
 ######################################################################
 
 
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
+def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
     nb_to_create = nb_for_train + nb_for_test
-
-    recorded_quizzes_logprobas = []
-
+    nb_to_generate_per_iteration = nb_to_create
     nb_validated = 0
 
-    start_time = time.perf_counter()
+    recorded = []
 
-    nb_to_generate_per_iteration = nb_to_create
+    start_time = time.perf_counter()
 
     while nb_validated < nb_to_create:
         model_for_generation = models[torch.randint(len(models), (1,))]
@@ -425,19 +407,11 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
             temperature=args.generation_temperature,
         )
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+        c_quizzes = keep_good_quizzes(models, c_quizzes)
 
-        if c_quizzes.size(0) > 0:
-            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
-            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+        nb_validated += c_quizzes.size(0)
 
-            (
-                validated_quizzes,
-                validated_logprobas,
-            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
-
-            if validated_quizzes is not None:
-                nb_validated = validated_quizzes.size(0)
+        recorded.append(c_quizzes)
 
         duration = time.perf_counter() - start_time
 
@@ -454,6 +428,9 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
             f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
         )
 
+    validated_quizzes = torch.cat(recorded, dim=0)
+
+    ######################################################################
     # store the new c_quizzes which have been validated
 
     v_train = validated_quizzes[:nb_for_train]
@@ -465,20 +442,12 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
     quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False)
 
     ######################################################################
-    # save images with their logprobas
+    # save images
 
     vq = validated_quizzes[:128]
-    vl = validated_logprobas[:128]
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
-        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
-        torch.save(vl, filename)
-        # with open(file_name, "w") as logp_file:
-        # for l in vl:
-        # s = " ".join([str(x.item()) for x in l])
-        # logp_file.write(s + "\n")
-
         quiz_machine.save_quiz_illustrations(
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
@@ -630,7 +599,7 @@ for n_epoch in range(args.nb_epochs):
     # re-compute the test errors
 
     if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
+        record_new_c_quizzes(
             models,
             quiz_machine,
             nb_for_train=args.nb_new_c_quizzes_for_train,
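
For context, the renamed record_new_c_quizzes() keeps the same accumulate-until-quota loop as before, just without carrying log-probabilities along. Below is a standalone sketch of that pattern with stub functions standing in for the quiz generator and for keep_good_quizzes(); none of the names or shapes are from the repository, and, as in the diff, the accumulated result may overshoot the quota, with the caller slicing off what it needs.

import torch


def accumulate_until_quota(nb_to_create, generate, keep):
    # Generate batches, filter each one, and stop once enough survivors
    # have been accumulated.
    nb_validated, recorded = 0, []
    while nb_validated < nb_to_create:
        batch = keep(generate(nb_to_create))
        nb_validated += batch.size(0)
        recorded.append(batch)
    return torch.cat(recorded, dim=0)


# Stubs standing in for the quiz generator and for keep_good_quizzes().
def generate(n):
    return torch.randn(n, 8)


def keep(q):
    return q[q[:, 0] > 0]  # keeps roughly half of each batch


print(accumulate_until_quota(1100, generate, keep).size())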