t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
k = j%2
- return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
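+ # A hedged note on shapes: t is (T, 1) and j is (1, D), so pe broadcasts
+ # to (T, D); even columns are sin and odd columns cos (the pi/2 phase
+ # shift), i.e. the standard sinusoidal encoding with base len_max.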
+ pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+ return x + pe # Let broadcasting do its job
##############################
nn.LayerNorm(dim_model),
QKVAttention(
dim_in = dim_model,
- dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+ dim_qk = dim_keys,
+ dim_v = dim_model // nb_heads,
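+ # dim_v = dim_model // nb_heads so the concatenated head outputs
+ # come back to dim_model (assumes nb_heads divides dim_model)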
nb_heads = nb_heads,
causal = True, attention_dropout = dropout
),