Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 06:21:16 +0000 (08:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 06:21:16 +0000 (08:21 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 6b00bbf..cdaacdf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -284,8 +284,6 @@ problem.save_some_examples(args.result_dir)
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
-    nb_train_samples=args.nb_train_samples,
-    nb_test_samples=args.nb_test_samples,
     back_accuracy=back_accuracy,
     batch_size=args.physical_batch_size,
     result_dir=args.result_dir,
@@ -414,11 +412,15 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
 
     nb_validated = 0
 
+    start_time = time.perf_counter()
+
+    nb_to_generate_per_iteration = nb_to_create
+
     while nb_validated < nb_to_create:
         model_for_generation = models[torch.randint(len(models), (1,))]
 
         c_quizzes = quiz_machine.generate_quizzes(
-            nb_to_create,
+            nb_to_generate_per_iteration,
             model_for_generation=model_for_generation,
             temperature=args.generation_temperature,
         )
@@ -437,8 +439,19 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
             if validated_quizzes is not None:
                 nb_validated = validated_quizzes.size(0)
 
+        duration = time.perf_counter() - start_time
+
+        if nb_validated > 0:
+            e = (nb_to_create - nb_validated) * duration / nb_validated
+            if e > 0:
+                e = "~" + str(datetime.timedelta(seconds=int(e)))
+            else:
+                e = "0s"
+        else:
+            e = "???"
+
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (remaining time {e})"
         )
 
     # store the new c_quizzes which have been validated
@@ -595,6 +608,10 @@ if args.dirty_debug:
     args.nb_new_c_quizzes_for_train = 100
     args.nb_new_c_quizzes_for_test = 10
 
+    def compute_valid_quizzes(token_logprobas):
+        l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+        return torch.rand(l[:, 0].size(), device=l.device) < 0.5
+
 
 ######################################################################
 
index bc468d3..927a349 100755 (executable)
@@ -244,8 +244,6 @@ class QuizMachine:
     def __init__(
         self,
         problem,
-        nb_train_samples,
-        nb_test_samples,
         back_accuracy,
         batch_size,
         result_dir,