Update.
author François Fleuret <francois@fleuret.org>
Sun, 28 Jul 2024 19:06:10 +0000 (21:06 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 28 Jul 2024 19:06:10 +0000 (21:06 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index a165696..0a148b1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,7 +89,7 @@ parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=1)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
 
 parser.add_argument("--proba_understands", type=float, default=0.99)
 
@@ -584,16 +584,17 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         )
 
         probas = seq_logproba.exp()
-        nb_sure_correct = (probas >= args.proba_understands).long().sum(dim=1)
-        nb_sure_fail = (probas <= args.proba_understands).long().sum(dim=1)
+
+        nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
+        nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
 
         to_keep = (
-            (nb_sure_correct + nb_sure_fail == probas.size(1))
-            & (nb_sure_fail >= 1)
-            & (nb_sure_fail <= args.max_fail_to_validate)
+            (nb_succeed + nb_fail == probas.size(1))
+            & (nb_fail >= 1)
+            & (nb_fail <= args.max_fail_to_validate)
         )
 
-        to_recycle = c_quizzes[to_keep == False] if not to_keep.all() else None
+        to_recycle = c_quizzes[to_keep == False]
         c_quizzes = c_quizzes[to_keep]
 
         if c_quizzes.size(0) > 0:
diff --git a/quiz_machine.py b/quiz_machine.py
index f147983..34f6b62 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -444,24 +444,40 @@ class QuizMachine:
 
     ###############################################################
 
-    def optimize_quizzes(self, quizzes, nb_variants, nb_iterations, struct, mask):
+    def optimize_quizzes(self, quiz, nb_variants, nb_iterations, struct, mask):
         for _ in range(nb_iterations):
-            candidates = quizzes[:, None].expand(-1, nb_variants, -1)
+            candidates = quizzes[None].expand(nb_variants, -1)
             r = torch.rand(candidates.size(), device=candidates.device)
-            u = r.reshape(
-                candidates.size(0) * candidates.size(1), 4, candidates.size(2) // 4
-            )
+            u = r.reshape(r.size(0), 4, candidates.size(1) // 4)
+            # Only change the part indicated by the mask and do not
+            # touch the special tokens
             u[:, :, 0] = 0
             u = u * torch.tensor(mask, device=u.device)[None, :, None]
             random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
-            random_mask[:, 0] = 0
+            # Keep the first unchanged
+            random_mask[:, 0, :] = 0
+            # Reshape without the 4 parts
             candidates.reshape(-1, candidates.size(-1))
             random_mask.reshape(candidates.size())
             random_tokens = torch.randint(
                 self.problem.nb_token_values - 4, random_mask.size()
             )
+            # Apply the noise
             candidates = (1 - random_mask) * candidates + random_mask * random_tokens
-            ar_mask = (self.make_ar_mask(candidates, struct, make_ar_mask),)
+            seq_logproba = quiz_machine.models_logprobas(
+                models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            ) + quiz_machine.models_logprobas(
+                models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+            )
+            sorted_logprobas = seq_logproba.sort(dim=1).values.exp()
+            lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1]
+            score = second_lowest - lowest
+
+            score = score * (second_lowest > args.proba_understands)
+
+            quiz = candidates[score.argmax()]
+
+        return quiz
 
     def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)
@@ -484,10 +500,11 @@ class QuizMachine:
                 logit_transformer=t,
             )
 
-            if to_recycle is not None:
+            if to_recycle is not None and to_recycle.size(0) > 0:
                 to_recycle = self.problem.reconfigure(to_recycle, s)
                 c_quizzes[: to_recycle.size(0)] = to_recycle
-                to_recycle = None
+
+            to_recycle = None
 
         c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))