From 933c9a500f2918635341a7b858172140c4535dee Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 12 Oct 2024 09:39:36 +0200 Subject: [PATCH] Update. --- attae.py | 39 +++++++++++++++++++++++++++++++++++++++ main.py | 50 ++++++++++++++++++++++++++++++++++---------------- 2 files changed, 73 insertions(+), 16 deletions(-) diff --git a/attae.py b/attae.py index bb97ed4..94da984 100755 --- a/attae.py +++ b/attae.py @@ -383,6 +383,45 @@ class Reasoning(nn.Module): residual_masker=residual_masker, ) + self.mha_A = MHAttention( + dim_model=dim_model, + dim_qk=dim_qk, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=attention_dropout, + ) + + self.mha_B = MHAttention( + dim_model=dim_model, + dim_qk=dim_qk, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=attention_dropout, + ) + + def forward_AB(self, x_q): + T, S = x_q.size(1), self.x_star.size(0) + nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks + + x = x_q + x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim) + x = self.trunk_A(x) + f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim) + f = self.mha_A(f, x) + + k = torch.arange(nb, device=x_q.device) + u = f[k * 2, :] + f[k * 2, :] = f[k * 2 + 1, :] + f[k * 2 + 1, :] = u + + f = self.mha_B(x, f) + x = self.trunk_B(x) + x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim) + + return x + def forward(self, x_q): T, S = x_q.size(1), self.x_star.size(0) nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks diff --git a/main.py b/main.py index 3b1caa9..c8d6f10 100755 --- a/main.py +++ b/main.py @@ -567,6 +567,9 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): if nb_samples % args.batch_size == 0: model.optimizer.step() + if train: + model.nb_epochs += 1 + log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}") @@ -941,7 +944,7 @@ if args.test == "aebn": pe, # trainable=True ) - nb_f_tokens = 100 + nb_f_tokens = 200 def no_f_residual(x): m = x.new_full((1, x.size(1), 1), 1.0) @@ -964,24 +967,39 @@ if args.test == "aebn": model.test_accuracy = 0.0 model.nb_epochs = 0 - for n_epoch in range(args.nb_epochs): - one_complete_epoch( - model, - n_epoch, - train_c_quizzes=None, - test_c_quizzes=None, - local_device=main_device, - ) + if args.resume: filename = f"aebn_{model.id:03d}.pth" - torch.save( - { - "state_dict": model.state_dict(), - "optimizer_state_dict": model.optimizer.state_dict(), - "test_accuracy": model.test_accuracy, - "nb_epochs": model.nb_epochs, - }, + + d = torch.load( os.path.join(args.result_dir, filename), + map_location="cpu", + weights_only=False, ) + model.load_state_dict(d["state_dict"]) + model.optimizer.load_state_dict(d["optimizer_state_dict"]) + model.test_accuracy = d["test_accuracy"] + model.nb_epochs = d["nb_epochs"] + log_string(f"successfully loaded {filename} nb_epochs {model.nb_epochs}") + + else: + for n_epoch in range(args.nb_epochs): + one_complete_epoch( + model, + n_epoch, + train_c_quizzes=None, + test_c_quizzes=None, + local_device=main_device, + ) + filename = f"aebn_{model.id:03d}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + "nb_epochs": model.nb_epochs, + }, + os.path.join(args.result_dir, filename), + ) exit(0) -- 2.39.5