From 39a1f9c00739e8902f233c1ac733e23fd34f808d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 11 Oct 2024 18:03:34 +0200 Subject: [PATCH] Update. --- attae.py | 119 +++++++++++++++++++++++++++++++------------------------ main.py | 34 ++++++++++++---- 2 files changed, 94 insertions(+), 59 deletions(-) diff --git a/attae.py b/attae.py index 0d36a33..bb97ed4 100755 --- a/attae.py +++ b/attae.py @@ -76,12 +76,17 @@ class AdHocPositionalEncoding(nn.Module): class WithResidual(nn.Module): - def __init__(self, *f): + def __init__(self, f, masker=None): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker def forward(self, x): - return x + self.f(x) + if self.masker is None: + mask = 1 + else: + mask = self.masker(x) + return mask * x + self.f(x) ###################################################################### @@ -135,28 +140,42 @@ class MHAttention(nn.Module): ###################################################################### -def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0): +def create_trunk( + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + residual_masker=None, +): trunk_blocks = [] for b in range(nb_blocks): trunk_blocks += [ WithResidual( - nn.LayerNorm((dim_model,)), - MHAttention( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention=vanilla_attention, - attention_dropout=dropout, + masker=residual_masker, + f=( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), ), ), WithResidual( - nn.LayerNorm((dim_model,)), - nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), + masker=residual_masker, + f=( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), ), ] @@ -214,12 +233,12 @@ class AttentionAE(nn.Module): def forward(self, x): x = self.embedding(x) - warnings.warn("flipping order for symmetry check", RuntimeWarning) - x = torch.cat([x[:, 200:], x[:, :200]], dim=1) + # warnings.warn("flipping order for symmetry check", RuntimeWarning) + x = self.positional_encoding(x) - x = torch.cat([x[:, 200:], x[:, :200]], dim=1) x = self.trunk(x) + x = self.readout(x) return x @@ -228,22 +247,6 @@ class AttentionAE(nn.Module): ###################################################################### -class WithMaskedResidual(nn.Module): - def __init__(self, masker, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - self.masker = masker - self.mask = None - - def forward(self, x): - if self.mask is None: - self.mask = self.masker(x) - return self.mask * x + self.f(x) - - -###################################################################### - - class FunctionalAttentionAE(nn.Module): def __init__( self, @@ -348,6 +351,7 @@ class Reasoning(nn.Module): attention=vanilla_attention, attention_dropout=0.0, len_max=1e5, + residual_masker=None, ): super().__init__() @@ -359,47 +363,58 @@ class Reasoning(nn.Module): self.positional_encoding = VaswaniPositionalEncoding(len_max) - self.trunk_joint = create_trunk( + self.trunk_A = create_trunk( dim_model=dim_model, dim_keys=dim_qk, dim_hidden=dim_hidden, nb_heads=nb_heads, nb_blocks=nb_blocks, dropout=attention_dropout, + residual_masker=residual_masker, ) - self.trunk_marginal = create_trunk( + self.trunk_B = create_trunk( dim_model=dim_model, dim_keys=dim_qk, dim_hidden=dim_hidden, nb_heads=nb_heads, nb_blocks=nb_blocks, dropout=attention_dropout, + residual_masker=residual_masker, ) def forward(self, x_q): - #!!!!!!!!!!!!!!!!!!!! - # x_q = torch.cat([x_q[:,200:,:], x_q[:,:200,:]],dim=1) - T, S = x_q.size(1), self.x_star.size(0) nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks + f = self.x_star.reshape(1, S, dim).expand(nb, S, dim) - x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim) - - x = torch.cat([x_star, x_q], dim=1) - x = self.trunk_joint(x) - - f, x = x[:, :S, :], x[:, S:, :] - x = x.reshape(nb * nc, T // nc, dim) + x = x_q + x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim) f = f.repeat(nc, 1, 1) x = torch.cat([f, x], dim=1) - x = self.trunk_marginal(x) + x = self.trunk_A(x) + k = torch.arange(nb, device=x_q.device) + u = x[k * 2, :S] + x[k * 2, :S] = x[k * 2 + 1, :S] + x[k * 2 + 1, :S] = u + x = self.trunk_B(x) + x = x[:, S:] + x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim) - x = x[:, S:, :] - x = x.reshape(nb, T, dim) + return x - #!!!!!!!!!!!!!!!!!!!! - # x = torch.cat([x[:,200:,:], x[:,:200,:]],dim=1) + def forward_one_vector(self, x_q): + T, S = x_q.size(1), self.x_star.size(0) + nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks + x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim) + x = torch.cat([x_star, x_q], dim=1) + x = self.trunk_A(x) + f = x[:, :S, :] + x = x_q + x = x + 1e-3 * f.mean(dim=1, keepdim=True) + x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim) + x = self.trunk_B(x) + x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim) return x diff --git a/main.py b/main.py index d5c1c5c..3b1caa9 100755 --- a/main.py +++ b/main.py @@ -527,15 +527,15 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): imt_set = torch.cat([b_p, b_g]) imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] + batch_size = args.batch_size + if train: label = "train" model.train().to(local_device) optimizer_to(model.optimizer, local_device) - batch_size = args.train_batch_size else: label = "test" model.eval().to(local_device) - batch_size = args.eval_batch_size nb_samples, acc_loss = 0, 0.0 @@ -932,20 +932,30 @@ if args.test == "aebn": # ) i = torch.arange(400)[:, None] - k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200] + k = [1, 2, 4, 8, 16, 10, 20, 40, 80, 160, 100, 200] k = torch.tensor(k)[None, :] - pe = (i // k) % 2 + pe = 2.0 * ((i // k) % 2) - 1.0 + + model.positional_encoding = attae.AdHocPositionalEncoding( + args.dim_model, + pe, # trainable=True + ) + + nb_f_tokens = 100 - model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe) + def no_f_residual(x): + m = x.new_full((1, x.size(1), 1), 1.0) + m[:, :nb_f_tokens, :] = 0 + return m model.trunk = attae.Reasoning( - nb_f_tokens=8, + nb_f_tokens=nb_f_tokens, nb_chunks=2, dim_model=args.dim_model, dim_qk=args.dim_keys, dim_hidden=args.dim_hidden, nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks // 2, + nb_blocks=args.nb_blocks, attention_dropout=args.dropout, ) @@ -962,6 +972,16 @@ if args.test == "aebn": test_c_quizzes=None, local_device=main_device, ) + filename = f"aebn_{model.id:03d}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + "nb_epochs": model.nb_epochs, + }, + os.path.join(args.result_dir, filename), + ) exit(0) -- 2.39.5