Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 16:01:40 +0000 (18:01 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 16:01:40 +0000 (18:01 +0200)
main.py

diff --git a/main.py b/main.py
index b2a9591..df29152 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -602,7 +602,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 
         for s in range(proba_own_solution.size(0)):
             dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
-            if not dont_get_this_quiz.all():
+            nb_fails = dont_get_this_quiz.long().sum()
+            if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
                 for model in models:
                     if dont_get_this_quiz[model.id]:
                         assert proba_own_solution[s, model.id] < args.proba_understands