self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(nb_heads, dim_in, dim_v)
self.causal = causal
self.attention_dropout = attention_dropout
v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
if self.causal:
- mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
+ mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
+ < torch.arange(a.size(3), device = q.device)[None, None, None, :]
a = a.masked_fill(mask, float('-inf'))
a = a.softmax(dim = 3)
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum('nhts,nhsd->nhtd', a, v)
- return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+ y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+
+ return y
##############################
######################################################################
if __name__ == '__main__':
+ print('Basic check.')
+
vocabulary_size = 10
x = torch.randint(vocabulary_size, (25, 100))