import torch, math, sys, argparse
-from torch import nn
+from torch import nn, einsum
from torch.nn import functional as F
import matplotlib.pyplot as plt
class AttentionLayer(nn.Module):
    """Single-head 1d self-attention over inputs shaped (N, C, T).

    Q and K are projected to `key_channels` and V to `out_channels`,
    each through a 1x1 convolution (a per-time-step linear map, no bias).
    No 1/sqrt(d) scaling is applied, matching the original toy code.
    """

    def __init__(self, in_channels, out_channels, key_channels):
        super().__init__()
        # kernel_size=1, bias=False: pure per-position linear projections.
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return attention-weighted values, shape (N, out_channels, T)."""
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        V = self.conv_V(x)
        # A[n, t, s] weights source position s for target position t;
        # softmax over dim 2 (sources), so each row of A sums to 1.
        A = torch.einsum('nct,ncs->nts', Q, K).softmax(2)
        y = torch.einsum('nts,ncs->nct', A, V)
        return y

    def __repr__(self):
        # NOTE(review): the original format string was truncated in the
        # source ("return self._get_name() + \"); reconstructed from the
        # constructor parameters — confirm against the full file.
        return self._get_name() + '(in_channels={}, out_channels={}, key_channels={})'.format(
            self.conv_Q.in_channels,
            self.conv_V.out_channels,
            self.conv_K.out_channels,
        )

    def attention(self, x):
        """Return only the attention matrix A, shape (N, T, T)."""
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        A = torch.einsum('nct,ncs->nts', Q, K).softmax(2)
        return A
######################################################################
# Detach the evaluation tensors from the autograd graph and move them to
# the CPU (presumably for the matplotlib plotting below — confirm against
# the rest of the file). The '+' diff markers baked into the source have
# been resolved to the post-patch statements.
test_input = test_input.detach().to('cpu')
test_outputs = test_outputs.detach().to('cpu')
test_targets = test_targets.detach().to('cpu')
test_bx = test_bx.detach().to('cpu')
test_tr = test_tr.detach().to('cpu')
for k in range(15):
save_sequence_images(