import torch, math, sys, argparse
-from torch import nn
+from torch import nn, einsum
from torch.nn import functional as F
import matplotlib.pyplot as plt
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
- A = Q.permute(0, 2, 1).matmul(K).softmax(2)
- x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
- return x
+ A = einsum('nct,ncs->nts', Q, K).softmax(2)
+ y = einsum('nts,ncs->nct', A, V)
+ return y
def __repr__(self):
return self._get_name() + \
def attention(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
- return Q.permute(0, 2, 1).matmul(K).softmax(2)
+ A = einsum('nct,ncs->nts', Q, K).softmax(2)
+ return A
######################################################################