% gitweb page residue, not part of the original file -- commit message: "Update."
% Repository path: [tex.git] / single-attention.tex
% -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; mode: yas/minor; coding: utf-8; tex-command: "pdflatex.sh" -*-

% Any copyright is dedicated to the Public Domain.
% https://creativecommons.org/publicdomain/zero/1.0/

% Written by Francois Fleuret <francois@fleuret.org>

% Beamer slides: frame contents vertically centred (c), 8pt base font.
\documentclass[c,8pt]{beamer}

% Hide the navigation symbols.
\setbeamertemplate{navigation symbols}{}

% Math shorthands.  \newcommand* / \DeclareMathOperator raise an error
% on a name clash instead of silently overwriting like \def does.
\newcommand*{\transpose}{^{\top}}         % X\transpose -> X^{\top}
\DeclareMathOperator{\softmax}{softmax}   % upright operator name

% Colour palette.  "blue" and "green" deliberately replace the
% predefined colours of the same name with darker shades, so every
% later use of those names (including blue!15, green!30) picks up
% these values.
\definecolor{blue}{rgb}{0.0,0.0,0.55}
\definecolor{green}{rgb}{0.0,0.50,0.0}
\definecolor{bluegray}{rgb}{0.1,0.2,0.7}

% Math is typeset in bluegray, local structure elements in the dark blue.
\setbeamercolor{math text}{fg=bluegray}
\setbeamercolor{local structure}{fg=blue}

\usepackage{tikz}

% TikZ libraries, each requested exactly once (calc used to be listed
% twice); the relative load order is unchanged.
\usetikzlibrary{positioning,fit,backgrounds}
\usetikzlibrary{arrows.meta,decorations.pathreplacing}
\usetikzlibrary{calc}
\usetikzlibrary{shapes,intersections}
\usetikzlibrary{patterns}

% NOTE(review): the "arrows" library is deprecated in favour of
% arrows.meta (loaded above).  Every arrow tip used in this file
% (Straight Barb, Bar) comes from arrows.meta, so this line can
% probably be dropped -- kept for now to stay backward-compatible.
\usetikzlibrary{arrows}

% Colours for data / parameter / processing boxes.
% NOTE(review): not referenced anywhere in this file.
\definecolor{nn-data}   {rgb}{0.90, 0.95, 1.00}
\definecolor{nn-param}  {rgb}{1.00, 0.90, 0.50}
\definecolor{nn-process}{rgb}{0.80, 1.00, 0.80}

% Default ">" arrow tip for all pictures.
\tikzset{>={Straight Barb[angle'=80,scale=1.1]}}

% Shared TikZ styles for the diagrams below.
\tikzset{
% Node styles: value (white box), parameter (blue tint), operation
% (green tint).  Option order matters: "inner xsep" must follow
% "inner sep", and "minimum height" must follow "minimum size".
  value/.style    ={ font=\scriptsize, rectangle, draw=black!50, fill=white,   thick,
                     inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt },
  parameter/.style={ font=\scriptsize, rectangle, draw=black!50, fill=blue!15, thick,
                     inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt },
  operation/.style={ font=\scriptsize, rectangle,    draw=black!50, fill=green!30, thick,
                     inner sep=3pt, minimum size=10pt, minimum height=20pt },
  flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick},
%
% Grey edge styles.  In the frames of this file, "v" ends attach to
% value nodes and "f" ends to operation nodes: a bar marks the start
% at a value, an arrow tip marks the arrival at a value.
  f2f/.style={draw=black!50, thick},
  v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick},
  f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick},
  v2v/.style={{Bar[width=1.5mm]}->,shorten <= 0.75pt,shorten >= 0.5pt,draw=black!50, thick},
%
% Black "d" variants: same geometry as the grey styles above, derived
% from them so the two families cannot drift apart; the trailing
% draw=black overrides the inherited draw=black!50.
  df2f/.style={f2f, draw=black},
  dv2f/.style={v2f, draw=black},
  df2v/.style={f2v, draw=black},
  dv2v/.style={v2v, draw=black},
%
% Yellow box with a larger font, plus its black flow arrow.
% NOTE(review): the d* styles, "differential" and "dflow" are not used
% in this file; presumably shared with other diagrams.
  differential/.style    ={ font=\small, rectangle, draw=black!50,               thick,
                     inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt, fill=yellow!80 },
  dflow/.style={flow, draw=black}
}

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

\begin{document}

% Frame 1: single-head attention as a dataflow diagram, followed by
% the two equations it depicts.
\begin{frame}

\begin{center}

\begin{tikzpicture}

% Matrices (value boxes) and operations.  K sits at the origin, Q above
% it and V below it; the rest of the chain is placed to the right of K.
\node[value,    minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
\node[value,    minimum height=1.2cm,minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$};
\node[value,    minimum height=0.8cm,minimum width=1.0cm] (V) [below=0.5cm of K] {$V$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
\node[value,    minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
\node[value,    minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};

% Q and K feed the transposed product, whose softmax gives A.
\draw[v2f,rounded corners=1mm] (K) -- (att);
\draw[v2f,rounded corners=1mm] (Q) -| (att);
\draw[f2f,rounded corners=1mm] (att) -- (sm);
% Arrows stop 1pt short of the target box, leaving room for its left bar.
\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);

% A and V feed the product that gives Y.
\draw[v2f,rounded corners=1mm] (A) -- (prod);
\draw[v2f,rounded corners=1mm] (V) -| (prod);
\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);

% Coloured bars along the top and left edge of each matrix.
% NOTE(review): equal colours presumably mark dimensions that have to
% agree (e.g. left of K, left of V and top of A are all red) -- confirm.
\foreach \mat/\col in {Q/yellow,K/yellow,V/orange,Y/orange,A/red} {
  \draw[very thick,color=\col] ([yshift=1pt]\mat.north west) -- ([yshift=1pt]\mat.north east);
}
\foreach \mat/\col in {V/red,K/red,Q/cyan,Y/cyan,A/cyan} {
  \draw[very thick,color=\col] ([xshift=-1pt]\mat.north west) -- ([xshift=-1pt]\mat.south west);
}

\end{tikzpicture}
% Deliberately no blank line here: a paragraph break right before a
% display adds spurious vertical space above it.
\begin{align*}
A & = \softmax_{\text{row}} \left( \frac{Q K\transpose}{\sqrt{D}} \right) \\
Y & = A V.
\end{align*}

Single-head attention

\end{center}

\end{frame}

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Frame 2: the same attention diagram, extended on the left with the
% product of X and W^Q that produces Q.
% NOTE(review): the nodes mulWk, mulWv and X' are placed but have no
% edges yet, and only W^Q is drawn -- the K and V branches look
% unfinished.
\begin{frame}

\begin{center}

\begin{tikzpicture}

% Q / K / V stack; each matrix is followed by its top and left bars.
\node[value,    minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east);
\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west);

\node[value,    minimum height=1.2cm,minimum width=0.7cm] (Q) [above=1cm of K] {$Q$};
\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east);
\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west);

\node[value,    minimum height=0.8cm,minimum width=1.0cm] (V) [below=1cm of K] {$V$};
\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east);
\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west);

% One product node per matrix, positioned relative to the matrix
% centre rather than its border.
\node[operation,minimum height=0.4cm] (mulWq) [left=1cm of Q.center] {$\cdot$};
\node[operation,minimum height=0.4cm] (mulWk) [left=1cm of K.center] {$\cdot$};
\node[operation,minimum height=0.4cm] (mulWv) [left=1cm of V.center] {$\cdot$};

% Input X and the parameter W^Q placed above it.
\node[value,    minimum height=1.2cm,minimum width=0.5cm] (X) [left=1cm of mulWq] {$X$};
\draw[very thick,cyan] ([xshift=-1pt]X.north west) -- ([xshift=-1pt]X.south west);
\draw[very thick,green] ([yshift=1pt]X.north west) -- ([yshift=1pt]X.north east);

\node[parameter,    minimum height=0.5cm,minimum width=0.7cm] (Wq) [above=0.25cm of X] {$W^Q$};
\draw[very thick,green] ([xshift=-1pt]Wq.north west) -- ([xshift=-1pt]Wq.south west);
\draw[very thick,yellow] ([yshift=1pt]Wq.north west) -- ([yshift=1pt]Wq.north east);

% Second input, not connected to anything yet.
\node[value,    minimum height=0.8cm,minimum width=0.3cm] (X') [below=1.2cm of X]  {$X'$};

% Attention chain to the right of K.
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};

\node[value,    minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west);
\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);

\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};

\node[value,    minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};
\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east);
\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west);

% X and W^Q feed the product that gives Q.
\draw[v2f,rounded corners=1mm] (X) -- (mulWq);
\draw[v2f,rounded corners=1mm] (Wq) -| (mulWq);
\draw[f2v,rounded corners=1mm] (mulWq) -- ([xshift=-1pt]Q.west);

% Q and K feed the transposed product, whose softmax gives A.
\draw[v2f,rounded corners=1mm] (K) -- (att);
\draw[v2f,rounded corners=1mm] (Q) -| (att);
\draw[f2f,rounded corners=1mm] (att) -- (sm);
\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);

% A and V feed the product that gives Y.
\draw[v2f,rounded corners=1mm] (A) -- (prod);
\draw[v2f,rounded corners=1mm] (V) -| (prod);
\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);

\end{tikzpicture}

\end{center}

\end{frame}

\end{document}