Update.
author    François Fleuret <francois@fleuret.org>
          Sun, 28 Jul 2024 15:18:33 +0000 (17:18 +0200)
committer François Fleuret <francois@fleuret.org>
          Sun, 28 Jul 2024 15:18:33 +0000 (17:18 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index ca84d3a..8eeb8de 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -91,7 +91,7 @@ parser.add_argument("--max_fail_to_validate", type=int, default=1)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
 
-parser.add_argument("--proba_understands", type=float, default=0.9)
+parser.add_argument("--proba_understands", type=float, default=0.99)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
@@ -577,33 +577,24 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         # This is nb_quizzes x nb_models
 
-        number_correct_responses = 0
-        nb_remaining = [c_quizzes.size(0)]
-
-        for r in range(args.nb_rounds):
-            if c_quizzes.size(0) == 0:
-                break
-
-            number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
-
-            nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
-            nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1)
-
-            to_keep = (
-                (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1))
-                & (nb_sure_fail >= 1)
-                & (nb_sure_fail <= args.max_fail_to_validate)
-            )
-
-            if not to_keep.all():
-                rejected.append(c_quizzes[to_keep == False])
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+        )
 
-            c_quizzes = c_quizzes[to_keep]
-            number_correct_responses = number_correct_responses[to_keep]
+        probas = seq_logproba.exp()
+        nb_sure_correct = (probas >= args.proba_understands).long().sum(dim=1)
+        nb_sure_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
 
-            nb_remaining.append(c_quizzes.size(0))
+        to_keep = (
+            (nb_sure_correct + nb_sure_fail == probas.size(1))
+            & (nb_sure_fail >= 1)
+            & (nb_sure_fail <= args.max_fail_to_validate)
+        )
 
-        to_recycle = torch.cat(rejected, dim=0) if len(rejected) > 0 else None
+        to_recycle = c_quizzes[to_keep == False] if not to_keep.all() else None
+        c_quizzes = c_quizzes[to_keep]
 
         if c_quizzes.size(0) > 0:
             nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
@@ -628,10 +619,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         else:
             e = "???"
 
-        v = " ".join([str(n) for n in nb_remaining])
-
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h) filtering {v}"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
     validated_quizzes = torch.cat(recorded_validated, dim=0)
@@ -651,24 +640,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
-        vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
-        number_correct_responses = 0
-
-        for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"):
-            number_correct_responses += quiz_machine.models_successes(models, vq)
-
-        seq_logproba = quiz_machine.models_logprobas(models, vq)
+        seq_logproba = quiz_machine.models_logprobas(
+            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        ) + quiz_machine.models_logprobas(
+            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+        )
 
         comments = []
 
-        for l, r in zip(seq_logproba, number_correct_responses):
-            comments.append(
-                "nb_correct "
-                + " ".join([str(n.item()) for n in r])
-                + "\n"
-                + "proba "
-                + " ".join([str(x.item()) for x in l])
-            )
+        for l in seq_logproba:
+            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
 
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
         quiz_machine.problem.save_quizzes_as_image(
diff --git a/quiz_machine.py b/quiz_machine.py
index 5dec85c..f147983 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -335,11 +335,13 @@ class QuizMachine:
 
     ######################################################################
 
-    def models_logprobas(self, models_for_validation, c_quizzes, device=None):
+    def models_logprobas(
+        self, models_for_validation, c_quizzes, struct, mask, device=None
+    ):
         if device is None:
             device = self.device
 
-        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+        c_quizzes = self.problem.reconfigure(c_quizzes, struct)
 
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
@@ -357,14 +359,14 @@ class QuizMachine:
                     seq_logproba.split(self.batch_size),
                 ):
                     input = input.to(device)
-                    ar_mask = self.make_ar_mask(input)
+                    ar_mask = self.make_ar_mask(input, struct, mask)
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
                             output.transpose(1, 2), input, reduction="none"
                         )
                         * ar_mask
-                    ).sum()
+                    ).sum(dim=1)
 
                 model.train(t)
 
@@ -442,6 +444,25 @@ class QuizMachine:
 
     ###############################################################
 
+    def optimize_quizzes(self, quizzes, nb_variants, nb_iterations, struct, mask):
+        for _ in range(nb_iterations):
+            candidates = quizzes[:, None].expand(-1, nb_variants, -1)
+            r = torch.rand(candidates.size(), device=candidates.device)
+            u = r.reshape(
+                candidates.size(0) * candidates.size(1), 4, candidates.size(2) // 4
+            )
+            u[:, :, 0] = 0
+            u = u * torch.tensor(mask, device=u.device)[None, :, None]
+            random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
+            random_mask[:, 0] = 0
+            candidates.reshape(-1, candidates.size(-1))
+            random_mask.reshape(candidates.size())
+            random_tokens = torch.randint(
+                self.problem.nb_token_values - 4, random_mask.size()
+            )
+            candidates = (1 - random_mask) * candidates + random_mask * random_tokens
+            ar_mask = self.make_ar_mask(candidates, struct, mask)
+
     def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)