self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(nb_heads, dim_in, dim_v)
+ self.w_o = randw(dim_in, dim_v * nb_heads)
self.causal = causal
self.attention_dropout = attention_dropout
a = a.masked_fill(mask, float('-inf'))
a = a.softmax(dim = 3)
a = F.dropout(a, self.attention_dropout, self.training)
- y = torch.einsum('nhts,nhsd->nhtd', a, v)
- y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+ y = torch.einsum('nhts,nhsd->nthd', a, v)
+ y = y.flatten(2) @ self.w_o
return y