Update.
author François Fleuret <francois@fleuret.org>
Sun, 28 Jul 2024 11:43:26 +0000 (13:43 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 28 Jul 2024 11:43:26 +0000 (13:43 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 83fd8b8..ca84d3a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -651,16 +651,25 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
+        vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
         number_correct_responses = 0
 
         for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"):
             number_correct_responses += quiz_machine.models_successes(models, vq)
 
+        seq_logproba = quiz_machine.models_logprobas(models, vq)
+
         comments = []
-        for r in number_correct_responses:
-            comments.append("nb_correct " + " ".join([str(n.item()) for n in r]))
 
-        vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
+        for l, r in zip(seq_logproba, number_correct_responses):
+            comments.append(
+                "nb_correct "
+                + " ".join([str(n.item()) for n in r])
+                + "\n"
+                + "proba "
+                + " ".join([str(x.item()) for x in l])
+            )
+
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir, filename, vq, comments=comments
@@ -906,7 +915,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             model.main_test_accuracy = 0.0
 
     ##################################################
-    # Select, improve, and eval the worst model
+    # Select, improve, and eval the worst model(s)
 
     ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
 
diff --git a/quiz_machine.py b/quiz_machine.py
index ba3387c..5dec85c 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -335,56 +335,65 @@ class QuizMachine:
 
     ######################################################################
 
-    def solution_token_logprobas(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(
+    def models_logprobas(self, models_for_validation, c_quizzes, device=None):
+        if device is None:
+            device = self.device
+
+        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+
+        seq_logproba = torch.zeros(
             c_quizzes.size(0),
-            len(models),
-            c_quizzes.size(1),
-            device=self.device,
-            dtype=torch.float32,
+            max([m.id for m in models_for_validation]) + 1,
+            device=device,
         )
 
-        for model in models:
+        for model in models_for_validation:
             with torch.autograd.no_grad():
                 t = model.training
                 model.eval()
 
                 for input, l in zip(
-                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                    c_quizzes.split(self.batch_size),
+                    seq_logproba.split(self.batch_size),
                 ):
-                    input = input.to(self.device)
-                    ar_mask = self.make_ar_mask(input, shape="fwd_3_bck_123")
+                    input = input.to(device)
+                    ar_mask = self.make_ar_mask(input)
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
                             output.transpose(1, 2), input, reduction="none"
                         )
                         * ar_mask
-                    )
+                    ).sum()
 
                 model.train(t)
 
-        return logproba.to("cpu")
+        return seq_logproba.to("cpu")
 
     ###############################################################
 
-    def models_successes(self, models_for_validation, c_quizzes):
+    def models_successes(self, models_for_validation, c_quizzes, device=None):
+        if device is None:
+            device = self.device
+
+        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
-            device=self.device,
+            device=device,
         )
 
         correctly_solved = torch.empty(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
-            device=self.device,
+            device=device,
             dtype=torch.int64,
         )
 
         seq_logproba[...] = 0.0
 
-        c_quizzes = c_quizzes.to(self.device)
+        c_quizzes = c_quizzes.to(device)
 
         reversed_c_quizzes = self.problem.reconfigure(
             c_quizzes, ("f_A", "A", "f_B", "B")