X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=single-attention.tex;h=d21595ee65aa3b00660982bf019b6782034e7de9;hb=070904ec73b192aff94b8f2e1db3a9fd0d80c320;hp=01a181c861be91739f0111693feda5253f7472e1;hpb=6e541e7102264b99f1a4aa72325a2b4b81fcb3eb;p=tex.git

# Unified diff of single-attention.tex (blob 01a181c -> d21595e, mode 100644).
#
# Layout: one hunk.  It starts at line 113 of the old file and covers 4 old
# lines (all context); the new side covers 70 lines, i.e. 4 context lines
# plus 66 added lines.  The added lines sit between the existing
# \end{frame} and \end{document}.
# NOTE(review): if any "+" line is added or removed, the "+113,70" count in
# the @@ header must be updated to match.
#
# What the added lines contain: a new frame holding a centred tikzpicture.
#   - Nodes styled "value":     K, Q, V, X, X', A, Y
#   - Node styled "parameter":  Wq (labelled $W^Q$)
#   - Nodes styled "operation": mulWq, mulWk, mulWv, prod (labelled $\cdot$),
#                               att (labelled $\cdot\transpose$),
#                               sm (labelled $\softmax$)
#   - After most value/parameter nodes, two "very thick" coloured lines are
#     drawn 1pt outside the node's top edge and left edge.  The same colours
#     recur across nodes (e.g. cyan on the left of Q, X, A, Y; yellow on the
#     top of K, Q, Wq; red on the left of K, V and on the top of A).
#     Presumably a colour marks a shared tensor dimension -- TODO confirm.
#   - Edges drawn:  X, Wq -> mulWq -> Q
#                   K, Q  -> att -> sm -> A
#                   A, V  -> prod -> Y
#
# NOTE(review): mulWk, mulWv and X' are declared but no \draw in this hunk
# connects them, and K and V have no incoming edge.  This looks like work in
# progress -- confirm before relying on the figure.
# NOTE(review): the styles value / parameter / operation / v2f / f2v / f2f
# and the macros \transpose and \softmax are not defined in this hunk; they
# are assumed to be defined elsewhere in the document.

diff --git a/single-attention.tex b/single-attention.tex
index 01a181c..d21595e 100644
--- a/single-attention.tex
+++ b/single-attention.tex
@@ -113,4 +113,70 @@ Single-head attention
 
 \end{frame}
 
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\begin{frame}
+
+\begin{center}
+
+\begin{tikzpicture}
+
+\node[value, minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
+\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east);
+\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west);
+
+\node[value, minimum height=1.2cm,minimum width=0.7cm] (Q) [above=1cm of K] {$Q$};
+\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east);
+\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west);
+
+\node[value, minimum height=0.8cm,minimum width=1.0cm] (V) [below=1cm of K] {$V$};
+\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east);
+\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west);
+
+\node[operation,minimum height=0.4cm] (mulWq) [left=1cm of Q.center] {$\cdot$};
+\node[operation,minimum height=0.4cm] (mulWk) [left=1cm of K.center] {$\cdot$};
+\node[operation,minimum height=0.4cm] (mulWv) [left=1cm of V.center] {$\cdot$};
+
+\node[value, minimum height=1.2cm,minimum width=0.5cm] (X) [left=1cm of mulWq] {$X$};
+\draw[very thick,cyan] ([xshift=-1pt]X.north west) -- ([xshift=-1pt]X.south west);
+\draw[very thick,green] ([yshift=1pt]X.north west) -- ([yshift=1pt]X.north east);
+
+\node[parameter, minimum height=0.5cm,minimum width=0.7cm] (Wq) [above=0.25 cm of X] {$W^Q$};
+\draw[very thick,green] ([xshift=-1pt]Wq.north west) -- ([xshift=-1pt]Wq.south west);
+\draw[very thick,yellow] ([yshift=1pt]Wq.north west) -- ([yshift=1pt]Wq.north east);
+
+\node[value, minimum height=0.8cm,minimum width=0.3cm] (X') [below=1.2cm of X] {$X'$};
+
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
+
+\node[value, minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
+\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west);
+\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);
+
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
+
+\node[value, minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};
+\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east);
+\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west);
+
+\draw[v2f,rounded corners=1mm] (X) -- (mulWq);
+\draw[v2f,rounded corners=1mm] (Wq) -| (mulWq);
+\draw[f2v,rounded corners=1mm] (mulWq) -- ([xshift=-1pt]Q.west);
+
+\draw[v2f,rounded corners=1mm] (K) -- (att);
+\draw[v2f,rounded corners=1mm] (Q) -| (att);
+\draw[f2f,rounded corners=1mm] (att) -- (sm);
+\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);
+
+\draw[v2f,rounded corners=1mm] (A) -- (prod);
+\draw[v2f,rounded corners=1mm] (V) -| (prod);
+\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);
+
+\end{tikzpicture}
+
+\end{center}
+
+\end{frame}
+
 \end{document}