import argparse
import math
import sys

import torch
import torch.nn as nn
class AttentionLayer(nn.Module):
    """Single-head dot-product attention over 1-D feature maps.

    Queries, keys and values are produced from the same input by three
    pointwise (kernel_size=1, bias-free) convolutions; the attention map
    is softmax-normalised over the key positions.

    Args:
        in_channels: channels of the input feature map ``(N, C_in, T)``.
        out_channels: channels of the returned feature map ``(N, C_out, T)``.
        key_channels: internal channel width shared by Q and K.
    """

    def __init__(self, in_channels, out_channels, key_channels):
        # NOTE: the original paste omitted super().__init__(); without it
        # nn.Module cannot register the submodules below.
        super().__init__()
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(N, C_in, T)``.

        Returns a tensor of shape ``(N, out_channels, T)``.
        """
        # NOTE(review): the Q/K/V projection calls were not visible in the
        # mangled source; reconstructed from the conv_Q/K/V members — confirm.
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        V = self.conv_V(x)
        # A[n, t, s] = softmax_s( Q[n, :, t] . K[n, :, s] ) — attention of
        # output position t over input position s (no 1/sqrt(d) scaling in
        # the original, preserved here).
        A = torch.einsum('nct,ncs->nts', Q, K).softmax(2)
        # y[n, c, t] = sum_s A[n, t, s] * V[n, c, s]
        y = torch.einsum('nts,ncs->nct', A, V)
        return y
# Move the evaluation tensors out of the autograd graph and onto the CPU
# so they can be inspected/plotted without holding GPU memory or gradient
# history. (The original paste repeated these three statements verbatim;
# the second pass was a no-op and has been removed.)
test_input = test_input.detach().to('cpu')
test_outputs = test_outputs.detach().to('cpu')
test_targets = test_targets.detach().to('cpu')