--- /dev/null
+#!/usr/bin/env python
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import flex_attention
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
+
+    def forward(self, x):
+        # x is of shape (N, T, D); t indexes positions, j indexes channels
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2
+
+        # a single expression covers both cases: the pi/2 phase shift turns the
+        # sine into a cosine on the odd channels (sin(u + pi/2) = cos(u))
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+
+        y = x + pe
+
+        return y
+
+
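+# Illustration of the encoding above (sizes chosen only as an example): with
+# D = 4 channels and base L = len_max, row t of pe is
+#
+#   [ sin(t),  cos(t),  sin(t / L^(1/2)),  cos(t / L^(1/2)) ]
+#
+# i.e. even channels carry sines and odd channels the matching cosines,
+# obtained through the + pi/2 phase shift inside the sin() call.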
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
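+# For instance, WithResidual(nn.LayerNorm((d,)), f) computes x + f(layer_norm(x)),
+# which is how the pre-norm blocks of the trunk are assembled below.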
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            # random weights scaled by 1 / sqrt(their last dimension)
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        # kept as attributes, but not used by the flex_attention-based forward below
+        self.attention_dropout = attention_dropout
+        self.record_attention = False
+
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(nb_heads, dim_v, dim_in)
+
+    def forward(self, x_q, x_kv=None):
+        # self-attention by default, cross-attention when x_kv is given
+        if x_kv is None:
+            x_kv = x_q
+
+        # einsum indices: n batch, t time, c input channel, h head, d head channel
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
+        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+
+        # full (non-causal, non-masked) attention with the default 1/sqrt(dim_qk) scaling
+        y = flex_attention(q, k, v)
+
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
+
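+# For instance (illustrative sizes), MHAttention(dim_in=16, dim_qk=8, dim_v=4,
+# nb_heads=2) maps a (N, T, 16) input to a (N, T, 16) output; passing a second
+# argument x_kv makes it a cross-attention layer, with queries computed from
+# x_q and keys / values from x_kv.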
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1024,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        # the embedding table has 2 * vocabulary_size rows, presumably so that a
+        # second, shifted copy of each token id (for instance, masked or corrupted
+        # inputs) can also be embedded; the readout maps back to vocabulary_size logits
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        # the encoding uses a fixed base of 1e5 (the len_max argument is not used here)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                # pre-norm attention block
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                # pre-norm feed-forward block
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x, mask=None):
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+if __name__ == "__main__":
+    model = AttentionAE(
+        vocabulary_size=100,
+        dim_model=16,
+        dim_keys=64,
+        dim_hidden=32,
+        nb_heads=4,
+        nb_blocks=4,
+        dropout=0.1,
+    )
+
+    x = torch.randint(100, (10, 50))
+
+    y = model(x)
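+
+    # A few sanity checks on shapes (illustrative only; the sizes used here are
+    # arbitrary and not prescribed anywhere else in this file)
+
+    print(f"output shape: {tuple(y.shape)}")
+
+    assert y.shape == (x.size(0), x.size(1), 100)
+
+    # the positional encoding only adds a (T, D) pattern to its input
+    pe = VaswaniPositionalEncoding(len_max=1e5)
+    z = torch.randn(2, 7, 16)
+    assert pe(z).shape == z.shape
+
+    # cross-attention: the query and key / value sequences can differ in length
+    mha = MHAttention(dim_in=16, dim_qk=8, dim_v=4, nb_heads=2)
+    x_q, x_kv = torch.randn(2, 5, 16), torch.randn(2, 9, 16)
+    assert mha(x_q, x_kv).shape == (2, 5, 16)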