Move averaging over noise rounds out of QuizMachine.models_logprobas and into its callers in main.py; raise the --nb_averaging_rounds default to 3.
author     François Fleuret <francois@fleuret.org>
           Fri, 2 Aug 2024 04:42:32 +0000 (06:42 +0200)
committer  François Fleuret <francois@fleuret.org>
           Fri, 2 Aug 2024 04:42:32 +0000 (06:42 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 69452e6..059a29d 100755
--- a/main.py
+++ b/main.py
@@ -104,7 +104,7 @@ parser.add_argument("--temperature_cold", type=float, default=1)
 
 parser.add_argument("--prompt_noise", type=float, default=0.0)
 
-parser.add_argument("--nb_averaging_rounds", type=int, default=1)
+parser.add_argument("--nb_averaging_rounds", type=int, default=3)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -162,7 +162,7 @@ assert not args.grids_science_tasks or (
 default_args = {
     "model": "37M",
     "batch_size": 25,
-    "inference_batch_size": 100,
+    "inference_batch_size": 50,
     "nb_train_samples": 100000,
     "nb_test_samples": 10000,
 }
@@ -345,7 +345,6 @@ quiz_machine = quiz_machine.QuizMachine(
     batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
     prompt_noise=args.prompt_noise,
-    nb_averaging_rounds=args.nb_averaging_rounds,
     logger=log_string,
     device=main_device,
 )
@@ -581,15 +580,20 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         c_quizzes = c_quizzes[to_keep]
 
-        # This is nb_quizzes x nb_models
+        probas = 0
 
-        seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
+        for a in range(args.nb_averaging_rounds):
+            # This is nb_quizzes x nb_models
 
-        probas = seq_logproba.exp()
+            seq_logproba = quiz_machine.models_logprobas(
+                models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            ) + quiz_machine.models_logprobas(
+                models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            )
+
+            probas += seq_logproba.exp()
+
+        probas /= args.nb_averaging_rounds
 
         nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
         nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
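The loop above is in effect a Monte Carlo estimate of each model's probability of completing the quiz correctly under prompt noise: every call to models_logprobas injects fresh noise, and the exponentiated log-probabilities are averaged over the rounds. A self-contained toy of the pattern, with a made-up stand-in for models_logprobas (the base values and the 0.1 noise scale are invented; the shape follows the "nb_quizzes x nb_models" comment):

    import torch

    nb_quizzes, nb_models, nb_averaging_rounds = 4, 3, 3

    # Stand-in for the two models_logprobas calls: one base log-probability
    # per (quiz, model), perturbed each round to mimic fresh prompt noise.
    base_logproba = torch.randn(nb_quizzes, nb_models) - 2.0

    probas = 0
    for a in range(nb_averaging_rounds):
        seq_logproba = base_logproba + 0.1 * torch.randn(nb_quizzes, nb_models)
        probas += seq_logproba.exp()

    probas /= nb_averaging_rounds  # estimates E[exp(seq_logproba)] entrywise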
@@ -650,11 +654,20 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
-        seq_logproba = quiz_machine.models_logprobas(
-            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
+        probas = 0
+
+        for a in range(args.nb_averaging_rounds):
+            # This is nb_quizzes x nb_models
+
+            seq_logproba = quiz_machine.models_logprobas(
+                models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            ) + quiz_machine.models_logprobas(
+                models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            )
+
+            probas += seq_logproba.exp()
+
+        probas /= args.nb_averaging_rounds
 
         comments = []
 
diff --git a/quiz_machine.py b/quiz_machine.py
index cfab73a..015f6d2 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -68,7 +68,6 @@ class QuizMachine:
         batch_size,
         result_dir,
         prompt_noise,
-        nb_averaging_rounds,
         logger,
         device=torch.device("cpu"),
     ):
@@ -80,11 +79,7 @@ class QuizMachine:
         self.logger = logger
         self.prompt_len = None
         self.answer_len = None
-
-        assert prompt_noise > 0 or nb_averaging_rounds == 1
-
         self.prompt_noise = prompt_noise
-        self.nb_averaging_rounds = nb_averaging_rounds
 
         self.understood_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -349,34 +344,33 @@ class QuizMachine:
             device=device,
         )
 
-        for a in range(self.nb_averaging_rounds):
-            if self.prompt_noise > 0.0 and noise_mask is not None:
-                c_quizzes = self.problem.inject_noise(
-                    c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
-                )
+        if self.prompt_noise > 0.0 and noise_mask is not None:
+            c_quizzes = self.problem.inject_noise(
+                c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
+            )
 
-            for model in models_for_validation:
-                with torch.autograd.no_grad():
-                    t = model.training
-                    model.eval()
-
-                    for input, l in zip(
-                        c_quizzes.split(self.batch_size),
-                        seq_logproba.split(self.batch_size),
-                    ):
-                        input = input.to(device)
-                        ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
-                        output = model(mygpt.BracketedSequence(input)).x
-                        l[:, model.id] += (
-                            -F.cross_entropy(
-                                output.transpose(1, 2), input, reduction="none"
-                            )
-                            * ar_mask
-                        ).sum(dim=1)
-
-                    model.train(t)
-
-        return seq_logproba.div(self.nb_averaging_rounds).to("cpu")
+        for model in models_for_validation:
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size),
+                    seq_logproba.split(self.batch_size),
+                ):
+                    input = input.to(device)
+                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
+                        * ar_mask
+                    ).sum(dim=1)
+
+                model.train(t)
+
+        return seq_logproba.to("cpu")
 
     ######################################################################
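After this change a single models_logprobas call injects noise at most once and returns one log-probability per (quiz, model) on the CPU, with no internal averaging. A usage sketch matching the call sites in main.py (quiz_machine, models and c_quizzes are assumed from the surrounding code; the tuples are the struct, mask and noise-mask arguments seen there):

    # One round only: nb_quizzes x nb_models log-probabilities; callers that
    # want noise-averaged probabilities loop and average exp(...) themselves.
    seq_logproba = quiz_machine.models_logprobas(
        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
    )
    probas_one_round = seq_logproba.exp()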