- self.w_o = randw(dim_in, dim_v * nb_heads)
- self.causal = causal
- self.attention_dropout = attention_dropout
+ self.w_o = randw(dim_v * nb_heads, dim_in)
+
+ def forward(self, x_q, x_kv=None):
+ if x_kv is None:
+ x_kv = x_q
+
+ q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+ k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
+ v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))