From 6a7c53e919d781b77c490aee14c2a61fa5af6407 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 29 Jul 2024 10:07:57 +0200 Subject: [PATCH] Update. --- mygpt.py | 18 ++++-- quiz_machine.py | 165 ++---------------------------------------------- 2 files changed, 20 insertions(+), 163 deletions(-) diff --git a/mygpt.py b/mygpt.py index fca2067..e2f317f 100755 --- a/mygpt.py +++ b/mygpt.py @@ -226,9 +226,10 @@ class QKVAttention(nn.Module): class NoiseInjector(nn.Module): - def __init__(self): + def __init__(self, identifier=None): super().__init__() self.noise_std = 0.0 + self.identifier = identifier def forward(self, x): if self.noise_std > 0: @@ -236,10 +237,17 @@ class NoiseInjector(nn.Module): return x -def set_noise_injection(model, noise_std): +def set_noise_injection(model, noise_std, identifier=None): for m in model.modules(): if isinstance(m, NoiseInjector): - m.noise_std = noise_std + if identifier is None or identifier == m.identifier: + m.noise_std = noise_std + + +def reset_noise_injection(model): + for m in model.modules(): + if isinstance(m, NoiseInjector): + m.noise_std = 0.0 ############################## @@ -275,7 +283,7 @@ class MyGPT(nn.Module): WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), - NoiseInjector(), + NoiseInjector(identifier=("attention", b)), ), QKVAttention( dim_in=dim_model, @@ -289,7 +297,7 @@ class MyGPT(nn.Module): WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), - NoiseInjector(), + NoiseInjector(identifier=("ffw", b)), nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), nn.Linear(in_features=dim_hidden, out_features=dim_model), diff --git a/quiz_machine.py b/quiz_machine.py index ca71c95..93e048d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -29,7 +29,6 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - logit_transformer=None, deterministic_synthesis=False, ): if input.size(0) == 0: @@ -46,9 +45,6 @@ def one_batch_masked_inplace_autoregression( logits = output[:, s] - if logit_transformer is not None: - logits = logit_transformer(s, logits).log_softmax(dim=-1) - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -107,7 +103,6 @@ class QuizMachine: input, ar_mask, seq_logproba=None, - logit_transformer=None, progress_bar_desc=None, ): assert input.size() == ar_mask.size() @@ -139,7 +134,6 @@ class QuizMachine: input=input, ar_mask=ar_mask, seq_logproba=seq_logproba, - logit_transformer=logit_transformer, deterministic_synthesis=False, ) @@ -484,12 +478,14 @@ class QuizMachine: return quiz.to("cpu") + ###################################################################### + def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): seq_logproba = torch.zeros(nb, device=self.device) c_quizzes = None - for s, m, t in procedure: + for s, m, mt in procedure: if c_quizzes is None: c_quizzes = self.problem.create_empty_quizzes(nb, s) c_quizzes = c_quizzes.to(self.device) @@ -497,14 +493,17 @@ class QuizMachine: c_quizzes = self.problem.reconfigure(c_quizzes, s) pred_s = s + mt(model_for_generation) + self.autoregression( model=model_for_generation, input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, s, m), seq_logproba=seq_logproba, - logit_transformer=t, ) + model_for_generation.reset_transformation() + if to_recycle is not None and to_recycle.size(0) > 0: to_recycle = self.problem.reconfigure(to_recycle, s) c_quizzes[: to_recycle.size(0)] = to_recycle @@ -516,153 +515,3 @@ class QuizMachine: return c_quizzes.to("cpu") ###################################################################### - - def generate_c_quizzes_orig( - self, - nb, - model_for_generation, - temperature_hot=1.0, - temperature_cold=1.0, - to_recycle=None, - ): - c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")) - c_quizzes = c_quizzes.to(self.device) - - seq_logproba = torch.zeros(nb, device=self.device) - - lt_noisy = lambda s, logits: logits / temperature_hot - lt_clean = lambda s, logits: logits / temperature_cold - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - ) - - if to_recycle is not None: - l = c_quizzes.size(1) // 4 - self.logger(f"recycling {to_recycle.size(0)} rejected quizzes") - c_quizzes[: to_recycle.size(0), :l] = to_recycle[:, 3 * l :] - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - ) - - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - ) - - return c_quizzes.to("cpu") - - ###################################################################### - - def generate_c_quizzes_( - self, - nb, - model_for_generation, - temperature_hot=1.0, - temperature_cold=1.0, - ): - warnings.warn( - "**************************** simple quiz generation", RuntimeWarning - ) - - c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) - c_quizzes = c_quizzes.to(self.device) - - seq_logproba = torch.zeros(nb, device=self.device) - - lt_noisy = lambda s, logits: logits / temperature_hot - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - ) - - return c_quizzes.to("cpu") - - ###################################################################### - - def generate_c_quizzes_2( - self, - nb, - model_for_generation, - temperature_hot=1.0, - temperature_cold=1.0, - ): - warnings.warn( - "**************************** simple quiz generation", RuntimeWarning - ) - - seq_logproba = torch.zeros(nb, device=self.device) - - lt_noisy = lambda s, logits: logits / temperature_hot - lt_clean = lambda s, logits: logits / temperature_cold - - c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) - c_quizzes = c_quizzes.to(self.device) - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 0, 0) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - ) - - c_quizzes2 = self.problem.create_empty_quizzes(nb, ("B", "f_B", "A", "f_A")) - c_quizzes2 = c_quizzes2.to(self.device) - - self.autoregression( - model=model_for_generation, - input=c_quizzes2, - ar_mask=self.make_ar_mask( - c_quizzes2, - ("B", "f_B", "A", "f_A"), - (1, 0, 0, 0), - ), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - ) - - l = c_quizzes.size(1) // 4 - c_quizzes[:, 2 * l : 3 * l] = c_quizzes2[:, :l] - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_ar_mask( - c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - ) - - return c_quizzes.to("cpu") - - ###################################################################### -- 2.39.5