# Build the query, key and value tensors by applying the weights
# self.w_q / self.w_k / self.w_v to the inputs. Each einsum contracts the
# shared last axis `c` of an (n, t, c) input against an (h, d, c) weight and
# returns an (n, h, t, d) tensor.
# NOTE(review): presumably n=batch, t=sequence position, h=heads and
# d=per-head width — confirm against the caller.
# Queries are computed from x_q; keys and values are both computed from x_kv.
# Each projection is evaluated exactly once: einsum has no side effects, so
# repeating the same call only recomputes an identical tensor.
q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)