Update. inv
authorFrançois Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 07:24:07 +0000 (09:24 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 07:24:07 +0000 (09:24 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 4a0c1f5..61820dd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -451,7 +451,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
     while nb_validated_per_model.sum() < nb_to_validate:
-        # We balance the number of quizzes per model
+        # We use the model that has generated the fewest quizzes to
+        # balance the number of quizzes per model overall
 
         model_for_generation = sorted(
             models, key=lambda m: nb_validated_per_model[m.id]
@@ -468,29 +469,39 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             temperature_cold=args.temperature_cold,
         )
 
-        # We discard the trivial ones
+        # We discard the trivial ones, according to a criterion
+        # specific to the world quizzes (e.g. B=f(B))
 
         c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
         # We go through nb_rounds rounds and keep only quizzes on
-        # which models respond always the same through rounds and one
-        # which N-1 succeed and one fails
+        # which
+        #
+        # (1) models respond always the same through rounds, and
+        #
+        # (2) at least one and up to max_fail_to_validate model(s)
+        # fail(s)
 
-        ms = 0  # "model scores"
+        # This is nb_quizzes x nb_models
+        number_correct_responses = 0
 
         for r in range(args.nb_rounds):
-            ms += quiz_machine.models_successes(models, c_quizzes)
-            nb_sure_and_correct = (ms == r + 1).long().sum(dim=1)
-            nb_sure_and_fail = (ms == 0).long().sum(dim=1)
+            number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
+
+            nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
+            nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1)
+
             to_keep = (
-                (nb_sure_and_correct + nb_sure_and_fail == ms.size(1))
-                & (nb_sure_and_fail >= 1)
-                & (nb_sure_and_fail <= args.max_fail_to_validate)
+                (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1))
+                & (nb_sure_fail >= 1)
+                & (nb_sure_fail <= args.max_fail_to_validate)
             )
 
             c_quizzes = c_quizzes[to_keep]
-            ms = ms[to_keep]
-            print(f"Round {r} remains {c_quizzes.size(0)}")
+            number_correct_responses = number_correct_responses[to_keep]
+
+            log_string(f"round {r} remains {c_quizzes.size(0)}")
+
             if c_quizzes.size(0) == 0:
                 break
 
@@ -552,6 +563,7 @@ models = []
 
 for k in range(args.nb_gpts):
     log_string(f"creating model {k} and its w_quizzes")
+
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -568,15 +580,8 @@ for k in range(args.nb_gpts):
 
     quiz_machine.create_w_quizzes(
         model=model,
-        nb=args.nb_train_samples,
-        for_train=True,
-        p2a_only=args.p2a_only,
-    )
-
-    quiz_machine.create_w_quizzes(
-        model=model,
-        nb=args.nb_test_samples,
-        for_train=False,
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
         p2a_only=args.p2a_only,
     )
 
@@ -733,11 +738,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # Renew the training samples
 
     for model in weakest_models:
-        quiz_machine.renew_w_quizzes(
-            model=model,
-            for_train=True,
-            p2a_only=args.p2a_only,
-        )
+        quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only)
 
     if args.log_command is not None:
         s = args.log_command.split()
index b1f6be1..d6c686e 100755 (executable)
@@ -357,45 +357,47 @@ class QuizMachine:
 
     ######################################################################
 
-    def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False):
-        input = self.generate_token_sequences(nb)
+    def create_w_quizzes(
+        self, model, nb_train_samples, nb_test_samples, p2a_only=False
+    ):
+        model.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+        model.test_w_quizzes = self.generate_token_sequences(nb_test_samples)
 
         if not p2a_only:
-            self.p_a_flip_half_in_place(input)
-
-        if for_train:
-            model.train_w_quizzes = input
-        else:
-            model.test_w_quizzes = input
+            self.p_a_flip_half_in_place(model.train_w_quizzes)
+            self.p_a_flip_half_in_place(model.test_w_quizzes)
 
     ######################################################################
 
-    def renew_w_quizzes(self, model, for_train=True, p2a_only=False):
-        input = model.train_w_quizzes if for_train else model.test_w_quizzes
-
+    def renew_train_w_quizzes(self, model, p2a_only=False):
         if for_train and hasattr(model, "hard_w_quizzes"):
             self.logger(
                 f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
             )
-            if model.hard_w_quizzes.size(0) >= input.size(0):
-                input[...] = model.hard_w_quizzes[
-                    torch.randperm(hard_w_quizzes.size(0))[input.size(0)]
+
+            if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+                model.train_w_quizzes[...] = model.hard_w_quizzes[
+                    torch.randperm(hard_w_quizzes.size(0))[
+                        model.train_w_quizzes.size(0)
+                    ]
                 ]
             else:
-                input[...] = torch.cat(
+                model.train_w_quizzes[...] = torch.cat(
                     [
                         model.hard_w_quizzes,
                         self.generate_token_sequences(
-                            input.size(0) - model.hard_w_quizzes.size(0)
+                            model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
                         ),
                     ],
                     dim=0,
                 )
         else:
-            input[...] = self.generate_token_sequences(input.size(0))
+            model.train_w_quizzes[...] = self.generate_token_sequences(
+                model.train_w_quizzes.size(0)
+            )
 
         if not p2a_only:
-            self.p_a_flip_half_in_place(input)
+            self.p_a_flip_half_in_place(model.train_w_quizzes)
 
     ######################################################################