From d6b0a90da3db5ed5c5bea521a8a5a85d03fca725 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 8 Sep 2024 09:44:33 +0200 Subject: [PATCH] Update. --- attae.py | 170 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100755 attae.py diff --git a/attae.py b/attae.py new file mode 100755 index 0000000..3a9f105 --- /dev/null +++ b/attae.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python + +import math + +import torch + +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import flex_attention + +###################################################################### + + +class VaswaniPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) + + def forward(self, x): + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 + + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) + + y = x + pe + + return y + + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + + +###################################################################### + + +class MHAttention(nn.Module): + def __init__( + self, + dim_in, + dim_qk, + dim_v, + nb_heads=1, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.attention_dropout = attention_dropout + self.record_attention = False + + self.w_q = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(nb_heads, dim_v, dim_in) + + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q + + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k) + v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v) + + y = flex_attention(q, k, v) + + y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) + + return y + + +###################################################################### + + +class AttentionAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1024, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x, mask=None): + x = self.embedding(x) + x = self.positional_encoding(x) + x = self.trunk(x) + x = self.readout(x) + return x + + +###################################################################### + + +if __name__ == "__main__": + model = AttentionAE( + vocabulary_size=100, + dim_model=16, + dim_keys=64, + dim_hidden=32, + nb_heads=4, + nb_blocks=4, + dropout=0.1, + ) + + x = torch.randint(100, (10, 50)) + + y = model(x) -- 2.39.5