Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 22:39:11 +0000 (00:39 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 22:39:11 +0000 (00:39 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 57f79a3..634363f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -392,7 +392,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis):
 
 def standard_validity(logproba):
     l = logproba.sort(dim=-1).values
-    return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95))
+    return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
+    # warnings.warn("TEST!!!", RuntimeWarning)
+    # print(l.exp())
+    # return (l[:, 0] < math.log(0.99))
 
 
 def valid_c_quizzes(recorded, criteria):
@@ -482,8 +485,8 @@ def create_c_quizzes_(
 
     # ------------------------------------------------------------
 
-    standard_validity = lambda nb_correct: torch.logical_and(
-        nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+    standard_validity = lambda nb_correct: (nb_correct >= args.min_to_validate) & (
+        nb_correct <= args.max_to_validate
     )
 
     file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
index c1477c9..0ae68d0 100755 (executable)
@@ -429,7 +429,7 @@ class QuizMachine:
                     F.cross_entropy(output.transpose(1, 2), input, reduction="none")
                     * ar_mask
                 )
-                l[:, model.id] = ce.sum(dim=-1)
+                l[:, model.id] = -ce.sum(dim=-1)
 
         return logproba