From fdd573490e517d38fb0477ae1b5df12b74718d45 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 3 Apr 2021 12:36:18 +0200 Subject: [PATCH] Simplified with Einstein summations. --- attentiontoy1d.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 2cecad8..d389f0c 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -7,7 +7,7 @@ import torch, math, sys, argparse -from torch import nn +from torch import nn, einsum from torch.nn import functional as F import matplotlib.pyplot as plt @@ -190,9 +190,9 @@ class AttentionLayer(nn.Module): Q = self.conv_Q(x) K = self.conv_K(x) V = self.conv_V(x) - A = Q.permute(0, 2, 1).matmul(K).softmax(2) - x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1) - return x + A = einsum('nct,ncs->nts', Q, K).softmax(2) + y = einsum('nts,ncs->nct', A, V) + return y def __repr__(self): return self._get_name() + \ @@ -205,7 +205,8 @@ class AttentionLayer(nn.Module): def attention(self, x): Q = self.conv_Q(x) K = self.conv_K(x) - return Q.permute(0, 2, 1).matmul(K).softmax(2) + A = einsum('nct,ncs->nts', Q, K).softmax(2) + return A ###################################################################### -- 2.39.5