Update.
author François Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 08:07:57 +0000 (10:07 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 08:07:57 +0000 (10:07 +0200)
mygpt.py
quiz_machine.py

index fca2067..e2f317f 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -226,9 +226,10 @@ class QKVAttention(nn.Module):
 
 
 class NoiseInjector(nn.Module):
-    def __init__(self):
+    def __init__(self, identifier=None):
         super().__init__()
         self.noise_std = 0.0
+        self.identifier = identifier
 
     def forward(self, x):
         if self.noise_std > 0:
@@ -236,10 +237,17 @@ class NoiseInjector(nn.Module):
         return x
 
 
-def set_noise_injection(model, noise_std):
+def set_noise_injection(model, noise_std, identifier=None):
     for m in model.modules():
         if isinstance(m, NoiseInjector):
-            m.noise_std = noise_std
+            if identifier is None or identifier == m.identifier:
+                m.noise_std = noise_std
+
+
+def reset_noise_injection(model):
+    for m in model.modules():
+        if isinstance(m, NoiseInjector):
+            m.noise_std = 0.0
 
 
 ##############################
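A minimal usage sketch of the selective noise API introduced above. The toy stack and dimensions are hypothetical; only the NoiseInjector / set_noise_injection / reset_noise_injection semantics come from the diff.

    import torch
    from torch import nn
    from mygpt import NoiseInjector, set_noise_injection, reset_noise_injection

    toy = nn.Sequential(
        NoiseInjector(identifier=("attention", 0)),
        nn.Linear(8, 8),
        NoiseInjector(identifier=("ffw", 0)),
    )

    # only the injector whose identifier matches is switched on
    set_noise_injection(toy, 0.1, identifier=("ffw", 0))
    stds = [m.noise_std for m in toy.modules() if isinstance(m, NoiseInjector)]
    assert stds == [0.0, 0.1]

    # with identifier=None (the default) every injector is affected,
    # and reset_noise_injection zeroes them all back
    set_noise_injection(toy, 0.2)
    reset_noise_injection(toy)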
@@ -275,7 +283,7 @@ class MyGPT(nn.Module):
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
+                        NoiseInjector(identifier=("attention", b)),
                     ),
                     QKVAttention(
                         dim_in=dim_model,
@@ -289,7 +297,7 @@ class MyGPT(nn.Module):
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
+                        NoiseInjector(identifier=("ffw", b)),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
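With these changes each block's pre-norm path carries a distinct ("attention", b) or ("ffw", b) identifier, so noise can be targeted at a single sub-layer of a single block. A hedged ablation loop, assuming a MyGPT instance `model` built as above and its block count `nb_blocks` at hand:

    from mygpt import set_noise_injection, reset_noise_injection

    for b in range(nb_blocks):  # nb_blocks assumed known
        for sub_layer in ("attention", "ffw"):
            set_noise_injection(model, 0.5, identifier=(sub_layer, b))
            # ... evaluate the perturbed model here ...
            reset_noise_injection(model)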
index ca71c95..93e048d 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -29,7 +29,6 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    logit_transformer=None,
     deterministic_synthesis=False,
 ):
     if input.size(0) == 0:
@@ -46,9 +45,6 @@ def one_batch_masked_inplace_autoregression(
 
         logits = output[:, s]
 
-        if logit_transformer is not None:
-            logits = logit_transformer(s, logits).log_softmax(dim=-1)
-
         if deterministic_synthesis:
             t_next = logits.argmax(-1)
         else:
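With the logit_transformer hook gone, each sampling step reduces to an arg-max or a categorical draw straight from the model's logits. A standalone sketch of that branch; the stand-in logits are random, and the Categorical draw is an assumption about the elided else body:

    import torch

    logits = torch.randn(4, 10)  # stand-in for output[:, s]
    deterministic_synthesis = False

    if deterministic_synthesis:
        t_next = logits.argmax(-1)
    else:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        t_next = dist.sample()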
@@ -107,7 +103,6 @@ class QuizMachine:
         input,
         ar_mask,
         seq_logproba=None,
-        logit_transformer=None,
         progress_bar_desc=None,
     ):
         assert input.size() == ar_mask.size()
@@ -139,7 +134,6 @@ class QuizMachine:
                     input=input,
                     ar_mask=ar_mask,
                     seq_logproba=seq_logproba,
-                    logit_transformer=logit_transformer,
                     deterministic_synthesis=False,
                 )
 
@@ -484,12 +478,14 @@ class QuizMachine:
 
         return quiz.to("cpu")
 
+    ######################################################################
+
     def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)
 
         c_quizzes = None
 
-        for s, m, t in procedure:
+        for s, m, mt in procedure:
             if c_quizzes is None:
                 c_quizzes = self.problem.create_empty_quizzes(nb, s)
                 c_quizzes = c_quizzes.to(self.device)
@@ -497,14 +493,17 @@ class QuizMachine:
                 c_quizzes = self.problem.reconfigure(c_quizzes, s)
             pred_s = s
 
+            mt(model_for_generation)
+
             self.autoregression(
                 model=model_for_generation,
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes, s, m),
                 seq_logproba=seq_logproba,
-                logit_transformer=t,
             )
 
+            model_for_generation.reset_transformation()
+
             if to_recycle is not None and to_recycle.size(0) > 0:
                 to_recycle = self.problem.reconfigure(to_recycle, s)
                 c_quizzes[: to_recycle.size(0)] = to_recycle
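The loop above implies that each procedure entry is a (structure, mask, model-transformation) triple: the transformation mt is applied to the model before generating, and reset_transformation (assumed to undo it) is called afterwards. A hypothetical procedure mirroring the removed generate_c_quizzes_orig, with noise injection standing in for the old hot/cold temperatures; the helper names, the 0.2 level, and the in-scope quiz_machine and model_for_generation are all illustrative:

    import mygpt

    def noisy(model):
        mygpt.set_noise_injection(model, 0.2)  # illustrative noise level

    def clean(model):
        mygpt.reset_noise_injection(model)

    procedure = [
        (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), noisy),
        (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), clean),
        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), clean),
    ]

    c_quizzes = quiz_machine.generate_c_quizzes(
        256, model_for_generation, procedure
    )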
@@ -516,153 +515,3 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
-
-    def generate_c_quizzes_orig(
-        self,
-        nb,
-        model_for_generation,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-        to_recycle=None,
-    ):
-        c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B"))
-        c_quizzes = c_quizzes.to(self.device)
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-        )
-
-        if to_recycle is not None:
-            l = c_quizzes.size(1) // 4
-            self.logger(f"recycling {to_recycle.size(0)} rejected quizzes")
-            c_quizzes[: to_recycle.size(0), :l] = to_recycle[:, 3 * l :]
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-        )
-
-        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-        )
-
-        return c_quizzes.to("cpu")
-
-    ######################################################################
-
-    def generate_c_quizzes_(
-        self,
-        nb,
-        model_for_generation,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        warnings.warn(
-            "**************************** simple quiz generation", RuntimeWarning
-        )
-
-        c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
-        c_quizzes = c_quizzes.to(self.device)
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-        )
-
-        return c_quizzes.to("cpu")
-
-    ######################################################################
-
-    def generate_c_quizzes_2(
-        self,
-        nb,
-        model_for_generation,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        warnings.warn(
-            "**************************** simple quiz generation", RuntimeWarning
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
-        c_quizzes = c_quizzes.to(self.device)
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 0, 0)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-        )
-
-        c_quizzes2 = self.problem.create_empty_quizzes(nb, ("B", "f_B", "A", "f_A"))
-        c_quizzes2 = c_quizzes2.to(self.device)
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes2,
-            ar_mask=self.make_ar_mask(
-                c_quizzes2,
-                ("B", "f_B", "A", "f_A"),
-                (1, 0, 0, 0),
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-        )
-
-        l = c_quizzes.size(1) // 4
-        c_quizzes[:, 2 * l : 3 * l] = c_quizzes2[:, :l]
-
-        self.autoregression(
-            model=model_for_generation,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(
-                c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-            ),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-        )
-
-        return c_quizzes.to("cpu")
-
-    ######################################################################