Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 18:10:29 +0000 (20:10 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 18:10:29 +0000 (20:10 +0200)
main.py

diff --git a/main.py b/main.py
index 6df33bd..2b71950 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -412,7 +412,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     nb_validated = torch.zeros(len(models))
 
-    while nb_validated < nb_to_create:
+    while nb_validated.sum() < nb_to_create:
         # We balance the number of quizzes per model
 
         model_for_generation = models[nb_validated.argmin()]
@@ -427,7 +427,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         c_quizzes = keep_good_quizzes(models, c_quizzes)
 
         nb_validated[model.id] += c_quizzes.size(0)
-        total_nb_validated = nb_validated.sum()
+        total_nb_validated = nb_validated.sum().item()
 
         recorded.append(c_quizzes)