projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
67d11b3
)
Simplified with Einstein summations.
author
Francois Fleuret
<francois@fleuret.org>
Sat, 3 Apr 2021 10:36:18 +0000
(12:36 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Sat, 3 Apr 2021 10:36:18 +0000
(12:36 +0200)
attentiontoy1d.py
patch
|
blob
|
history
diff --git
a/attentiontoy1d.py
b/attentiontoy1d.py
index
2cecad8
..
d389f0c
100755
(executable)
--- a/
attentiontoy1d.py
+++ b/
attentiontoy1d.py
@@
-7,7
+7,7
@@
import torch, math, sys, argparse
import torch, math, sys, argparse
-from torch import nn
+from torch import nn
, einsum
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torch.nn import functional as F
import matplotlib.pyplot as plt
@@
-190,9
+190,9
@@
class AttentionLayer(nn.Module):
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
- A =
Q.permute(0, 2, 1).matmul(
K).softmax(2)
-
x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1
)
- return
x
+ A =
einsum('nct,ncs->nts', Q,
K).softmax(2)
+
y = einsum('nts,ncs->nct', A, V
)
+ return
y
def __repr__(self):
return self._get_name() + \
def __repr__(self):
return self._get_name() + \
@@
-205,7
+205,8
@@
class AttentionLayer(nn.Module):
def attention(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
def attention(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
- return Q.permute(0, 2, 1).matmul(K).softmax(2)
+ A = einsum('nct,ncs->nts', Q, K).softmax(2)
+ return A
######################################################################
######################################################################