Added figures
[culture.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 import math
14
15 import torch
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 ######################################################################
21
# A BracketedSequence is a BxTx... tensor together with the index `first`
# of the first time step to compute and the number `nb` of time steps to
# compute.
24
25 # Modules able to process it expect that they will have to process a
26 # first bracket starting at t=0, followed by a succession of brackets
27 # that move forward in time, do not overlap, and cover the axis T with
28 # no holes.
29 #
30 # Although it is more general, for a classical prompt-conditioned
31 # auto-regressive process it will be a first bracket starting at 0 and
32 # of arbitrary length for the "prompt", followed by brackets of length
33 # 1 for the successive tokens.
34 #
35 # Modules able to process brackets may implement a cache that is
# reset when the input bracket starts at t=0.
37
38
class BracketedSequence:
    """A tensor paired with the time window [first, first + nb) to process.

    Time is the second axis (dim 1) of ``x``. ``first`` defaults to 0 and
    ``nb`` defaults to ``x.size(1)``, so ``BracketedSequence(x)`` covers
    the whole sequence.
    """

    def __init__(self, x, first=None, nb=None):
        self.x = x
        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        self.first = first
        self.nb = nb

    def slice(self):
        """Return the part of ``x`` that lies inside the bracket (along dim 1)."""
        start, stop = self.first, self.first + self.nb
        return self.x[:, start:stop]

    def complete(self):
        """Return True when the bracket spans the entire time axis of ``x``."""
        covers_start = self.first == 0
        covers_end = self.nb == self.x.size(1)
        return covers_start and covers_end
49         return self.first == 0 and self.nb == self.x.size(1)
50
51
52 ######################################################################
53
54
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket only, caching its outputs.

    The wrapped module(s) are applied to ``bs.slice()``. The results are
    written into a full-length output cache. That cache is (re)allocated
    whenever a bracket starts at t=0, so that later brackets fill in the
    remaining time steps.

    Args:
        *f: a single module, or several modules chained in an
            ``nn.Sequential``.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        y = self.f(bs.slice())

        if bs.first == 0:
            # New sequence: allocate a cache covering the full time axis of
            # the input, keeping the trailing dimensions of f's output.
            # Zero-filled (rather than the legacy, uninitialized Tensor.new)
            # so time steps not computed yet hold deterministic values
            # instead of arbitrary memory that may contain NaN/inf.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])

        self.cache_y[:, bs.first : bs.first + bs.nb] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb)
69
70
71 ##############################
72
73
class WithResidual(nn.Module):
    """Residual connection on bracketed sequences: output = x + f(bs).x.

    Args:
        *f: a single module, or several modules chained in an
            ``nn.Sequential``; it receives the BracketedSequence and its
            result's ``.x`` is added to the input tensor.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        residual = bs.x
        branch = self.f(bs)
        return BracketedSequence(residual + branch.x, bs.first, bs.nb)
81
82
83 ##############################
84
85
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a bracketed sequence.

    [Vaswani et al. 2017] PE_{t,2i} = sin(t / L^{2i/D}) and
    PE_{t,2i+1} = cos(t / L^{2i/D}). Here L is ``len_max`` and D is the
    size of dim 2 of the input.

    Args:
        len_max: the base L of the wavelength progression.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        if bs.first == 0:
            # New sequence: build the (T, D) encoding table for the full
            # time axis, and a fresh output cache.
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # For odd j, sin(angle + pi/2) == cos(angle), which yields the
            # cosine columns; j - k is the even index 2i shared by each pair.
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled (rather than the legacy, uninitialized Tensor.new)
            # so time steps not computed yet hold deterministic values.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb)
112
113
114 ##############################
115
116
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with key/value caches.

    Keys, values and outputs are stored in caches that are reallocated
    whenever a bracket starts at t=0. Each later bracket then only projects
    its own time steps, and its queries attend over all keys cached so far.

    Args:
        dim_in: size of the input (and output) feature dimension.
        dim_qk: per-head size of queries and keys.
        dim_v: per-head size of values.
        nb_heads: number of attention heads.
        causal: if True, query t cannot attend to keys s > t. Partial
            (bracketed) evaluation is only allowed in that case.
        attention_dropout: dropout probability applied to attention weights.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            # Gaussian weights scaled by 1/sqrt(size of the last dimension).
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Creation order (w_q, w_k, w_v, w_o) fixes both the random numbers
        # drawn for each weight and the parameter / state_dict order.
        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs_q):
        x_q = bs_q.x
        start = bs_q.first
        stop = bs_q.first + bs_q.nb

        assert (
            self.causal or bs_q.complete()
        ), "Partial evaluation is only possible for causal models"

        batch_size, seq_len = x_q.size(0), x_q.size(1)

        if start == 0:
            # New sequence: fresh per-head key/value caches over the full
            # time axis, and an output cache of width dim_in.
            self.cache_k = x_q.new_zeros(
                batch_size, self.w_k.size(0), seq_len, self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                batch_size, self.w_v.size(0), seq_len, self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(batch_size, seq_len, self.w_o.size(1))

        x_new = x_q[:, start:stop]

        # Per-head projections of the new time steps only.
        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        keys = self.cache_k[:, :, :stop]
        values = self.cache_v[:, :, :stop]

        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, keys) / scale

        if self.causal:
            if start == 0:
                positions = torch.arange(seq_len, device=q.device)
                # True where the key position s lies after the query
                # position t, i.e. the entries to exclude.
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            future_mask = self.cache_attzero[:, :, start:stop, :stop]
            a = a.masked_fill(future_mask, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            # Kept before dropout.
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        # Weighted sum of values, heads concatenated along the last axis.
        y = torch.einsum("nhts,nhsd->nthd", a, values).flatten(2)

        self.cache_y[:, start:stop] = y @ self.w_o

        return BracketedSequence(self.cache_y, start, bs_q.nb)
199
200
201 ##############################
202
203
class NoiseInjector(nn.Module):
    """Add i.i.d. Gaussian noise with standard deviation ``noise_std``.

    ``noise_std`` starts at 0.0, which makes the module an identity. Noise
    is added whenever ``noise_std > 0``, independently of train/eval mode.
    """

    def __init__(self):
        super().__init__()
        self.noise_std = 0.0

    def forward(self, x):
        if self.noise_std > 0:
            # randn_like samples with x's dtype, device and layout. The
            # previous torch.randn(x.size(), device=x.device) always used
            # the default dtype, which up-cast half/bfloat16 activations.
            x = x + torch.randn_like(x) * self.noise_std
        return x
213
214
def set_noise_injection(model, noise_std):
    """Set ``noise_std`` on every NoiseInjector reachable via ``model.modules()``."""
    injectors = (m for m in model.modules() if isinstance(m, NoiseInjector))
    for injector in injectors:
        injector.noise_std = noise_std
219
220
221 ##############################
222
223
class MyGPT(nn.Module):
    """Transformer model built from bracket-aware, caching modules.

    Layout:
    - token embedding followed by dropout, then sinusoidal positional
      encoding;
    - ``nb_blocks`` pairs of residual blocks, each applying LayerNorm and
      a NoiseInjector before either multi-head attention or a ReLU MLP;
    - a linear readout to ``vocabulary_size`` outputs.

    Args:
        vocabulary_size: number of token values (embedding rows, readout size).
        dim_model: width of the residual stream; must be divisible by
            ``nb_heads`` (each head gets ``dim_model // nb_heads`` value dims).
        dim_keys: per-head size of queries and keys.
        dim_hidden: hidden width of the MLP blocks.
        nb_heads: number of attention heads.
        nb_blocks: number of (attention, MLP) block pairs.
        causal: whether attention is causal, which partial (bracketed)
            evaluation requires.
        dropout: dropout probability for embeddings, attention weights and
            MLP outputs.
        len_max: base of the positional-encoding wavelengths.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        causal=False,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0, "dim_model must be divisible by nb_heads"

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                # x + Attention(Noise(LayerNorm(x)))
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        NoiseInjector(),
                    ),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                    ),
                ),
                # x + Dropout(Linear(ReLU(Linear(Noise(LayerNorm(x))))))
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        NoiseInjector(),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        # Embeddings: N(0, 0.02^2). LayerNorms: identity affine (weight 1, bias 0).
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, bs):
        """Process the bracket ``bs`` and return the readout's BracketedSequence.

        The input is shifted right by one step along its last dimension: a
        0 is inserted at the front and the last entry is dropped. With
        ``causal=True``, the output at step t therefore depends only on
        inputs before t.

        Assumes ``bs.x`` holds integer token indices of shape (N, T). TODO
        confirm for callers other than the self-check below.
        """
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)
        return bs

    def record_attention(self, v=True):
        """Turn storing of attention maps on (or off with ``v=False``) in every QKVAttention."""
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        """Return the last stored attention map of each QKVAttention, in module order.

        Raises AttributeError if some QKVAttention has not stored a map yet.
        A map is only stored during a forward pass while recording is enabled.
        """
        return [m.a for m in self.modules() if isinstance(m, QKVAttention)]
309
310
311 ######################################################################
312
if __name__ == "__main__":
    # Self-check: running the model token by token through the caches must
    # reproduce the output of a single pass over the full sequence.
    print("Basic check.")

    vocabulary_size = 3
    tokens = torch.randint(vocabulary_size, (1, 5))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=4,
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
        nb_blocks=2,
        dropout=0.1,
        causal=True,
    )

    # eval() disables dropout, so both passes are deterministic.
    model.eval()

    y_full = model(BracketedSequence(tokens)).x

    # Placeholder, overwritten one time step at a time below.
    y_step = torch.randn_like(y_full)
    for t in range(tokens.size(1)):
        y_step[:, t] = model(BracketedSequence(tokens, t, 1)).slice()

    relative_error = (y_full - y_step).norm() / (y_full.norm() + y_step.norm())
    print(f"error={relative_error.item()}")
336
337     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
338
339 ######################################################################