# [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
- def forward(self, bs):
+ def forward(self, bs, order): # NxTxD, NxT
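+ # order[n, t] is the position index assigned to token t of sample n, so the
+ # positional encoding can follow an arbitrary, per-sample ordering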
if bs.first == 0:
- t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
- :, None
- ]
- j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+ t = (
+ torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+ :, None
+ ]
+ - 1
+ )
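+ # t runs over -1, 0, ..., T-1; row 0 of the table encodes the slot before
+ # the first token, selected below when the input order is shifted right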
+ j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
None, :
]
k = j % 2
- self.pe = torch.sin(
- t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ pe = (
+ torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ .unsqueeze(0)
+ .expand(bs.x.size(0), -1, -1)
)
+
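+ # output positions are each token's own order, input positions the order of
+ # the preceding token (shifted right by one); the +1 offsets into pe, whose
+ # row r encodes position r - 1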
+ order_output = order + 1
+ order_input = F.pad(order + 1, (1, -1))
+
+ pe_input = pe.gather(
+ 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+ pe_output = pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+
+ self.pe = torch.cat((pe_input, pe_output), 2)
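+ # each token gets the encoding of its input position and of its output
+ # position, concatenated along the feature dimension (D/2 + D/2 = D)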
self.cache_y = bs.x.new(bs.x.size())
self.cache_y[:, bs.first : bs.first + bs.nb] = (
- bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+ bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
)
bs.x = self.cache_y
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
+ amm_generator=None,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
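+ # amm_generator(d) must return a d x d boolean matrix; True entries are
+ # masked out of the attention matrix. The default is the standard causal
+ # mask: a query cannot attend to a strictly later position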
+ if amm_generator is None:
+ self.amm_generator = (
+ lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
+ )
+ else:
+ self.amm_generator = amm_generator
+
self.causal = causal
self.attention_dropout = attention_dropout
if self.causal:
if bs_q.first == 0:
- self.cache_attzero = (
- torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
- < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
- )
+ self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+ None, None, :, :
+ ]
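+ # the mask is cached with leading batch and head dimensions so that it
+ # broadcasts against the attention matrix below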
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
causal=False,
dropout=0.0,
len_max=1e5,
+ amm_generator=None,
):
super().__init__()
assert dim_model % nb_heads == 0
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
+ self.embedding = CacheWrapper(
+ nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
)
+ self.pe = AddPositionalEncoding(len_max)
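+ # the positional encoding is applied separately from the embedding so that
+ # forward() can pass it the per-sample order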
trunk_blocks = []
nb_heads=nb_heads,
causal=causal,
attention_dropout=dropout,
+ amm_generator=amm_generator,
),
),
WithResidual(
m.bias.zero_()
m.weight.fill_(1.0)
- def forward(self, bs, mode="standard"):
- bs.x = F.pad(bs.x, (1, -1))
+ def forward(self, bs, mode="standard", order=None):
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
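+ # F.pad(bs.x, (1, -1)) prepends a 0 and drops the last token, shifting the
+ # input right by one position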
+ if order is None:
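+ # default to the canonical left-to-right order 0, 1, ..., T-1 for every sample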
+ order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+ bs.x
+ )
bs = self.embedding(bs)
+ bs = self.pe(bs, order)
+
if mode == "standard":
bs = self.trunk(bs)
bs = self.readout(bs)
r += [bs.slice()]
bs = BracketedSequence(torch.cat(r, -1))
else:
- raise ValueError
+ raise ValueError(f"{mode=}")
return bs