From b39bcde2b91433b895bb11493ab5d0143c5055f5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 9 Oct 2024 10:53:32 +0200 Subject: [PATCH] Update. --- attae.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 43 +++++++++++++++++++++----- 2 files changed, 127 insertions(+), 7 deletions(-) diff --git a/attae.py b/attae.py index 2b231de..c386e3a 100755 --- a/attae.py +++ b/attae.py @@ -286,6 +286,97 @@ class FunctionalAttentionAE(nn.Module): ###################################################################### +class Reasoning(nn.Module): + def __init__( + self, + nb_f_tokens, + nb_chunks, + dim_model, + dim_qk, + dim_hidden, + nb_heads=1, + nb_blocks=1, + attention=vanilla_attention, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.nb_chunks = nb_chunks + self.x_star = randw(nb_f_tokens, dim_model) + + self.trunk_joint = create_trunk( + dim_model=dim_model, + dim_keys=dim_qk, + dim_hidden=dim_hidden, + nb_heads=nb_heads, + nb_blocks=nb_blocks, + dropout=attention_dropout, + ) + + self.trunk_marginal = create_trunk( + dim_model=dim_model, + dim_keys=dim_qk, + dim_hidden=dim_hidden, + nb_heads=nb_heads, + nb_blocks=nb_blocks, + dropout=attention_dropout, + ) + + def forward(self, x_q): + T, S = x_q.size(1), self.x_star.size(0) + nb, dim = x_q.size(0), x_q.size(2) + x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1) + + x = torch.cat([x_star, x_q], dim=1) + x = self.trunk_joint(x) + + f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :] + x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim) + f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim) + x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim) + f = f.reshape(nb * self.nb_chunks, S, dim) + x = torch.cat([f, x], dim=1) + x = self.trunk_marginal(x) + + x = x[:, x_star.size(1) :, :] + x = x.reshape(nb, T, dim) + + return x + + +###################################################################### + + +class LearningToBeMe(nn.Module): + def __init__(self, real_f, mimic_f, epsilon): + super().__init__() + self.real_f = real_f + self.mimic_f = mimic_f + self.epsilon = epsilon + + def forward(self, x): + y_real_f = self.real_f(x) + y_mimic_f = self.mimic_f(x.detach()) + z_real_f = y_real_f.flatten(1) + z_mimic_f = y_mimic_f.flatten(1) + e = ( + (z_real_f.detach() - z_mimic_f).pow(2).sum(dim=1, keepdim=True) + / z_real_f.detach().pow(2).sum(dim=1, keepdim=True) + ).sqrt() + self.error = e.mean() + m = (e <= self.epsilon).float() + self.nb_me = (1 - m).mean() + z = (1 - m) * z_real_f + m * z_mimic_f + y = z.reshape(y_real_f.shape) + return y + + +###################################################################### + + if __name__ == "__main__": model = FunctionalAttentionAE( vocabulary_size_in=100, diff --git a/main.py b/main.py index f6cf450..41862a2 100755 --- a/main.py +++ b/main.py @@ -557,6 +557,14 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): logits.transpose(1, 2), targets, reduction="none" ) loss = (loss_per_token * masks).mean() + + if args.test == "aebn": + error = 0 + for m in model.modules(): + if hasattr(m, "error"): + error = error + m.error + loss = loss + error + acc_loss += loss.item() * imt.size(0) nb_samples += imt.size(0) @@ -566,6 +574,14 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): if nb_samples % args.batch_size == 0: model.optimizer.step() + if args.test == "aebn": + nb_me = [] + for m in model.modules(): + if hasattr(m, "nb_me"): + nb_me.append(m.nb_me.item()) + + log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}") + log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}") @@ -913,16 +929,27 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### if args.test == "aebn": - model = new_model() + model = new_model(0) + # f = model.trunk[len(model.trunk) // 2 :] + model.trunk = attae.Reasoning( + nb_f_tokens=250, + nb_chunks=2, + dim_model=args.dim_model, + dim_qk=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks // 2, + attention_dropout=args.dropout, + ) - # model.trunk = ( - # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :] + # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential( + # attae.LearningToBeMe(f, g, 1e-1) # ) - model.id = 0 - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - model.test_accuracy = 0.0 - model.nb_epochs = 0 + # model.id = 0 + # model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + # model.test_accuracy = 0.0 + # model.nb_epochs = 0 for n_epoch in range(args.nb_epochs): one_complete_epoch( @@ -933,6 +960,8 @@ if args.test == "aebn": local_device=main_device, ) + exit(0) + ###################################################################### train_c_quizzes, test_c_quizzes = None, None -- 2.39.5