Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 08:14:17 +0000 (10:14 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 08:14:17 +0000 (10:14 +0200)
main.py
mygpt.py
quiz_machine.py

diff --git a/main.py b/main.py
index 0a148b1..e553278 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -451,18 +451,24 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
-lt_noisy = lambda s, logits: logits / args.temperature_hot
-lt_clean = lambda s, logits: logits / args.temperature_cold
+
+def model_transformer_hot(model):
+    model.temperature = args.temperature_hot
+
+
+def model_transformer_cold(model):
+    model.temperature = args.temperature_cold
+
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), lt_noisy),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), lt_clean),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), lt_clean),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
 
 c_quizzes_procedure_ = [
-    (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), lt_noisy),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), lt_clean),
+    (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), model_transformer_hot),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
 ]
 
 
index e2f317f..c073113 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -237,19 +237,6 @@ class NoiseInjector(nn.Module):
         return x
 
 
-def set_noise_injection(model, noise_std, identifier=None):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            if identifier is None or identifier == m.identifier:
-                m.noise_std = noise_std
-
-
-def reset_noise_injection(model):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            m.noise_std = 0.0
-
-
 ##############################
 
 
@@ -271,6 +258,8 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
+        self.temperature = 1.0
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
@@ -345,6 +334,7 @@ class MyGPT(nn.Module):
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
         return bs
 
     def encode(self, bs):
@@ -374,6 +364,18 @@ class MyGPT(nn.Module):
             bs = self.readout(bs)
             return bs
 
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in model.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
+
     def record_attention(self, v=True):
         for m in self.modules():
             if isinstance(m, QKVAttention):
index 93e048d..cf70b91 100755 (executable)
@@ -493,7 +493,8 @@ class QuizMachine:
                 c_quizzes = self.problem.reconfigure(c_quizzes, s)
             pred_s = s
 
-            mt(model_for_generation)
+            if mt is not None:
+                mt(model_for_generation)
 
             self.autoregression(
                 model=model_for_generation,
@@ -502,7 +503,7 @@ class QuizMachine:
                 seq_logproba=seq_logproba,
             )
 
-            model_for_generation.reset_transformation()
+            model_for_generation.reset_transformations()
 
             if to_recycle is not None and to_recycle.size(0) > 0:
                 to_recycle = self.problem.reconfigure(to_recycle, s)