Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 11:35:21 +0000 (13:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 11:35:21 +0000 (13:35 +0200)
main.py

diff --git a/main.py b/main.py
index bb79484..b48d2a8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1319,6 +1319,14 @@ def c_quiz_criterion_one_good_one_bad(probas):
     return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
 
 
+def c_quiz_criterion_one_good_no_very_bad(probas):
+    return (
+        (probas.max(dim=1).values >= 0.75)
+        & (probas.min(dim=1).values <= 0.75)
+        & (probas.min(dim=1).values >= 0.25)
+    )
+
+
 def c_quiz_criterion_diff(probas):
     return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
 
@@ -1328,6 +1336,11 @@ def c_quiz_criterion_diff2(probas):
     return (v[:, -2] - v[:, 0]) >= 0.5
 
 
+def c_quiz_criterion_only_one(probas):
+    v = probas.sort(dim=1).values
+    return (v[:, -1] >= 0.75) & (v[:, -2] <= 0.25)
+
+
 def c_quiz_criterion_two_good(probas):
     return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2)
 
@@ -1340,7 +1353,9 @@ def c_quiz_criterion_some(probas):
 
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
+        # c_quiz_criterion_only_one,
         c_quiz_criterion_one_good_one_bad,
+        # c_quiz_criterion_one_good_no_very_bad,
         # c_quiz_criterion_diff,
         # c_quiz_criterion_diff2,
         # c_quiz_criterion_two_good,