- q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
- k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
- v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
+ q = torch.einsum('ntc,hdc->nhtd', x, self.w_q)
+ k = torch.einsum('ntc,hdc->nhtd', x, self.w_k)
+ v = torch.einsum('ntc,hdc->nhtd', x, self.w_v)