Update.
author     François Fleuret <francois@fleuret.org>
           Tue, 9 Jul 2024 21:00:50 +0000 (23:00 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 9 Jul 2024 21:00:50 +0000 (23:00 +0200)
main.py
problem.py
quiz_machine.py

diff --git a/main.py b/main.py
index 3004f9c..57f79a3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -392,7 +392,7 @@ def run_tests(model, quiz_machine, deterministic_synthesis):
 
 def standard_validity(logproba):
     l = logproba.sort(dim=-1).values
-    return logical_and(l[0] < math.log(0.5), l[1] > math.log(0.95))
+    return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95))
 
 
 def valid_c_quizzes(recorded, criteria):
@@ -435,17 +435,10 @@ def create_c_quizzes(
             c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
             if c_quizzes.size(0) > 0:
-                logproba = c_quizzes.new(c_quizzes.size(0), len(models))
-                for q, l in zip(
-                    c_quizzes.split(args.batch_size), logproba.split(args.batch_size)
-                ):
-                    for model in models:
-                        l[model.id] = F.cross_entropy(model(q))
-
+                logproba = quiz_machine.logproba_solution(models, c_quizzes)
                 for l in logproba:
                     s = " ".join([str(x.item()) for x in l])
                     logp_file.write(s + "\n")
-
                 quizzes_and_logproba_records.append((c_quizzes, logproba))
 
             nb_validated = valid_c_quizzes(
@@ -655,6 +648,9 @@ for n_epoch in range(args.nb_epochs):
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones
 
+    log_string(
+        f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
+    )
     quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
 
     ##################################################
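A minimal sketch (not part of the commit) of how the vectorized standard_validity criterion behaves, assuming logproba is a (nb_quizzes, nb_models) tensor of per-model log-probabilities: after sorting each row in ascending order, a quiz is kept when the lowest score is below log(0.5) and the second lowest is above log(0.95), i.e. one model still struggles while every other model is confident.

    import math
    import torch

    # Hypothetical scores: rows are candidate c_quizzes, columns are models.
    logproba = torch.tensor(
        [
            [math.log(0.30), math.log(0.97), math.log(0.98)],  # one model struggles, the rest are confident
            [math.log(0.80), math.log(0.97), math.log(0.98)],  # every model is already comfortable
        ]
    )

    l = logproba.sort(dim=-1).values  # ascending sort per row
    valid = (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95))
    print(valid)  # tensor([ True, False])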
diff --git a/problem.py b/problem.py
index eceb904..617b2a8 100755 (executable)
--- a/problem.py
+++ b/problem.py
@@ -19,6 +19,12 @@ class Problem:
         else:
             self.queue = None
 
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
     def nb_token_values(self):
         pass
 
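A self-contained sketch of what nb_cached_quizzes reports, with a toy stand-in for the cached-quiz machinery (the real Problem presumably fills a multiprocessing queue with chunks of chunk_size pre-generated quizzes): qsize() * chunk_size estimates how many quizzes are ready to be consumed, and None signals that no background cache is in use.

    import queue

    class ToyProblem:
        # Toy stand-in for problem.Problem, only to illustrate the new method.
        def __init__(self, chunk_size=100, use_cache=True):
            self.chunk_size = chunk_size
            self.queue = queue.Queue() if use_cache else None

        def nb_cached_quizzes(self):
            # Same logic as the patch: None when no cache is used, otherwise
            # an estimate of how many quizzes are sitting in the queue.
            if self.queue is None:
                return None
            return self.queue.qsize() * self.chunk_size

    problem = ToyProblem(chunk_size=100)
    problem.queue.put(["quiz"] * problem.chunk_size)  # one cached chunk
    print(f"cache_w_quizzes contains {problem.nb_cached_quizzes()} quizzes")  # -> 100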
diff --git a/quiz_machine.py b/quiz_machine.py
index 321df35..c1477c9 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -416,6 +416,25 @@ class QuizMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,
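A shape-level sketch of what logproba_solution computes per batch, using synthetic tensors rather than the repository's mygpt model and make_ar_mask (the half-sequence mask below is an assumption for illustration): the model output of shape (N, T, V) is transposed so F.cross_entropy with reduction="none" yields one loss per token, the autoregressive mask keeps only the positions to be predicted, and the sum over tokens gives one score per quiz and per model.

    import torch
    import torch.nn.functional as F

    N, T, V = 4, 10, 32                       # quizzes, tokens per quiz, vocabulary size
    input = torch.randint(V, (N, T))          # token ids of the c_quizzes
    output = torch.randn(N, T, V)             # stands in for model(mygpt.BracketedSequence(input)).x
    ar_mask = (torch.arange(T) >= T // 2).float().expand(N, T)  # hypothetical: score the second half only

    ce = F.cross_entropy(output.transpose(1, 2), input, reduction="none") * ar_mask
    per_quiz_score = ce.sum(dim=-1)           # one summed cross-entropy per quiz, shape (N,)
    print(per_quiz_score.shape)               # torch.Size([4])

In the full loop, each column of the returned (nb_quizzes, nb_models) tensor is filled this way for one model, and the resulting matrix is what create_c_quizzes logs and passes to the validity criterion.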