From 563de696ca14339aa4d514b09acaea5442fdb002 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 29 Jul 2024 10:14:17 +0200
Subject: [PATCH] Update.

---
 main.py         | 20 +++++++++++++-------
 mygpt.py        | 28 +++++++++++++++-------------
 quiz_machine.py |  5 +++--
 3 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/main.py b/main.py
index 0a148b1..e553278 100755
--- a/main.py
+++ b/main.py
@@ -451,18 +451,24 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
-lt_noisy = lambda s, logits: logits / args.temperature_hot
-lt_clean = lambda s, logits: logits / args.temperature_cold
+
+def model_transformer_hot(model):
+    model.temperature = args.temperature_hot
+
+
+def model_transformer_cold(model):
+    model.temperature = args.temperature_cold
+
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), lt_noisy),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), lt_clean),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), lt_clean),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
 
 c_quizzes_procedure_ = [
-    (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), lt_noisy),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), lt_clean),
+    (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), model_transformer_hot),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
 ]
 
 
diff --git a/mygpt.py b/mygpt.py
index e2f317f..c073113 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -237,19 +237,6 @@ class NoiseInjector(nn.Module):
         return x
 
 
-def set_noise_injection(model, noise_std, identifier=None):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            if identifier is None or identifier == m.identifier:
-                m.noise_std = noise_std
-
-
-def reset_noise_injection(model):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            m.noise_std = 0.0
-
-
 ##############################
 
 
@@ -271,6 +258,8 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
+        self.temperature = 1.0
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
@@ -345,6 +334,7 @@ class MyGPT(nn.Module):
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
         return bs
 
     def encode(self, bs):
@@ -374,6 +364,18 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         return bs
 
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
+
     def record_attention(self, v=True):
         for m in self.modules():
             if isinstance(m, QKVAttention):
diff --git a/quiz_machine.py b/quiz_machine.py
index 93e048d..cf70b91 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -493,7 +493,8 @@ class QuizMachine:
             c_quizzes = self.problem.reconfigure(c_quizzes, s)
             pred_s = s
 
-            mt(model_for_generation)
+            if mt is not None:
+                mt(model_for_generation)
 
             self.autoregression(
                 model=model_for_generation,
@@ -502,7 +503,7 @@ class QuizMachine:
                 seq_logproba=seq_logproba,
             )
 
-        model_for_generation.reset_transformation()
+        model_for_generation.reset_transformations()
 
         if to_recycle is not None and to_recycle.size(0) > 0:
             to_recycle = self.problem.reconfigure(to_recycle, s)
-- 
2.20.1
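
Not part of the patch: a minimal sketch of the temperature mechanism the commit moves into the model. TinyModel is a hypothetical stand-in for MyGPT (which scales only the bracketed slice of a BracketedSequence), and the hard-coded 2.0 / 0.5 stand in for args.temperature_hot / args.temperature_cold.

# Illustrative only -- not part of the commit.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, vocabulary_size=16, dim_model=32):
        super().__init__()
        self.temperature = 1.0  # default, as set in MyGPT.__init__
        self.embedding = nn.Embedding(vocabulary_size, dim_model)
        self.readout = nn.Linear(dim_model, vocabulary_size)

    def forward(self, x):
        logits = self.readout(self.embedding(x))
        # Same idea as the patched MyGPT.forward: divide the output
        # logits by the temperature currently set on the model.
        return logits / self.temperature


# Hypothetical stand-ins for model_transformer_hot / model_transformer_cold.
def model_transformer_hot(model):
    model.temperature = 2.0  # flatter distribution, more exploration


def model_transformer_cold(model):
    model.temperature = 0.5  # sharper distribution, closer to greedy


model = TinyModel()
x = torch.randint(0, 16, (1, 8))

model_transformer_hot(model)
probs_hot = torch.softmax(model(x), dim=-1)

model_transformer_cold(model)
probs_cold = torch.softmax(model(x), dim=-1)

# Restore the default, as reset_transformations() does for MyGPT.
model.temperature = 1.0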