Add --nb_averaging_rounds: average validation log-probabilities over multiple prompt-noise rounds.
author     François Fleuret <francois@fleuret.org>
           Thu, 1 Aug 2024 21:13:22 +0000 (23:13 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 1 Aug 2024 21:13:22 +0000 (23:13 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 94c030a..69452e6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -104,6 +104,8 @@ parser.add_argument("--temperature_cold", type=float, default=1)
 
 parser.add_argument("--prompt_noise", type=float, default=0.0)
 
+parser.add_argument("--nb_averaging_rounds", type=int, default=1)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test_generator", action="store_true", default=False)
@@ -343,6 +345,7 @@ quiz_machine = quiz_machine.QuizMachine(
     batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
     prompt_noise=args.prompt_noise,
+    nb_averaging_rounds=args.nb_averaging_rounds,
     logger=log_string,
     device=main_device,
 )
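
The new flag flows from the command line straight into the QuizMachine constructor above. A hypothetical invocation enabling the averaging (the values are chosen purely for illustration):

    ./main.py --prompt_noise 0.05 --nb_averaging_rounds 4

As the quiz_machine.py hunk below makes explicit, an --nb_averaging_rounds larger than 1 is only meaningful together with a positive --prompt_noise.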
diff --git a/quiz_machine.py b/quiz_machine.py
index 0fdfbf6..cfab73a 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -68,6 +68,7 @@ class QuizMachine:
         batch_size,
         result_dir,
         prompt_noise,
+        nb_averaging_rounds,
         logger,
         device=torch.device("cpu"),
     ):
@@ -79,7 +80,11 @@ class QuizMachine:
         self.logger = logger
         self.prompt_len = None
         self.answer_len = None
+
+        assert prompt_noise > 0 or nb_averaging_rounds == 1
+
         self.prompt_noise = prompt_noise
+        self.nb_averaging_rounds = nb_averaging_rounds
 
         self.understood_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -338,39 +343,40 @@ class QuizMachine:
 
         c_quizzes = self.problem.reconfigure(c_quizzes, struct)
 
-        if self.prompt_noise > 0.0 and noise_mask is not None:
-            c_quizzes = self.problem.inject_noise(
-                c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
-            )
-
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
             device=device,
         )
 
-        for model in models_for_validation:
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-
-                for input, l in zip(
-                    c_quizzes.split(self.batch_size),
-                    seq_logproba.split(self.batch_size),
-                ):
-                    input = input.to(device)
-                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
-                    output = model(mygpt.BracketedSequence(input)).x
-                    l[:, model.id] = (
-                        -F.cross_entropy(
-                            output.transpose(1, 2), input, reduction="none"
-                        )
-                        * ar_mask
-                    ).sum(dim=1)
-
-                model.train(t)
-
-        return seq_logproba.to("cpu")
+        for a in range(self.nb_averaging_rounds):
+            if self.prompt_noise > 0.0 and noise_mask is not None:
+                c_quizzes = self.problem.inject_noise(
+                    c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
+                )
+
+            for model in models_for_validation:
+                with torch.autograd.no_grad():
+                    t = model.training
+                    model.eval()
+
+                    for input, l in zip(
+                        c_quizzes.split(self.batch_size),
+                        seq_logproba.split(self.batch_size),
+                    ):
+                        input = input.to(device)
+                        ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+                        output = model(mygpt.BracketedSequence(input)).x
+                        l[:, model.id] += (
+                            -F.cross_entropy(
+                                output.transpose(1, 2), input, reduction="none"
+                            )
+                            * ar_mask
+                        ).sum(dim=1)
+
+                    model.train(t)
+
+        return seq_logproba.div(self.nb_averaging_rounds).to("cpu")
 
     ######################################################################
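
For reference, a minimal standalone sketch of the estimator the large hunk above implements: per-round prompt noise, masked cross-entropy summed over the scored positions, accumulation with +=, and a final division by the number of rounds. The names noise_fn and ar_mask are illustrative stand-ins for problem.inject_noise and make_ar_mask, not the repository's actual implementations, and, unlike the committed loop, the sketch re-noises the clean batch each round rather than the output of the previous round.

    import torch
    import torch.nn.functional as F

    def averaged_seq_logproba(model, quizzes, noise_fn, ar_mask, nb_rounds):
        # quizzes: (N, T) long tensor of token indices
        # ar_mask: (N, T) 0/1 tensor selecting the scored (answer) positions
        logproba = torch.zeros(quizzes.size(0))
        for _ in range(nb_rounds):
            noisy = noise_fn(quizzes)  # fresh noise on every round
            logits = model(noisy)      # (N, T, vocab_size)
            # summed log-probability of the masked tokens under the model
            logproba += (
                -F.cross_entropy(logits.transpose(1, 2), noisy, reduction="none")
                * ar_mask
            ).sum(dim=1)
        return logproba / nb_rounds    # mean over the noise rounds

In the commit itself, the accumulator is one column of seq_logproba per validating model, and the division by nb_averaging_rounds happens once, in the returned value.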