Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 18:47:42 +0000 (20:47 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 18:47:42 +0000 (20:47 +0200)
main.py

diff --git a/main.py b/main.py
index 2b71950..41efc86 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -410,12 +410,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     start_time = time.perf_counter()
 
-    nb_validated = torch.zeros(len(models))
+    nb_validated = torch.zeros(len(models), dtype=torch.int64)
 
     while nb_validated.sum() < nb_to_create:
         # We balance the number of quizzes per model
 
-        model_for_generation = models[nb_validated.argmin()]
+        model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
+        print(nb_validated, "using", model_for_generation.id)
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
@@ -426,7 +427,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         c_quizzes = keep_good_quizzes(models, c_quizzes)
 
-        nb_validated[model.id] += c_quizzes.size(0)
+        nb_validated[model_for_generation.id] += c_quizzes.size(0)
         total_nb_validated = nb_validated.sum().item()
 
         recorded.append(c_quizzes)
@@ -442,7 +443,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration}/h)"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)"
         )
 
     validated_quizzes = torch.cat(recorded, dim=0)