Oups
[picoclvr.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.
12
13 import math
14
15 import torch
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 ######################################################################
21
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: nb steps starting at step first.

# Modules able to process it expect to receive a first bracket
# starting at t=0, followed by a succession of brackets that move
# forward in time, do not overlap, and cover the axis T with no holes.
#
# Although the mechanism is more general, for a classical
# prompt-conditioned auto-regressive process there will be a first
# bracket starting at 0, of arbitrary length, for the "prompt",
# followed by brackets of length 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is reset
# when the input bracket starts at t=0.
37
38
class BracketedSequence:
    """A (B, T, ...) tensor paired with the bracket of time steps to compute.

    The bracket covers the ``nb`` time steps of axis 1 starting at
    ``first``. By default it spans the whole time axis.
    """

    def __init__(self, x, first=None, nb=None):
        self.x = x
        # Default bracket: start at t=0 and run to the end of the time axis.
        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)

    def slice(self):
        """Return the part of ``x`` that falls inside the bracket."""
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        """Return whether the bracket starts at t=0 and spans the full time axis."""
        starts_at_origin = self.first == 0
        return starts_at_origin and self.nb == self.x.size(1)
50
51
52 ######################################################################
53
54
class CacheWrapper(nn.Module):
    """Apply a per-time-step module to a BracketedSequence and cache the result.

    The wrapped module(s) are applied only to the current bracket. Their
    output is written into ``self.cache_y``, which is (re)allocated whenever a
    bracket starts at t=0. The returned sequence therefore holds every step
    computed since that reset.
    """

    def __init__(self, *f):
        super().__init__()
        # Several modules are chained with nn.Sequential.
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        y = self.f(bs.slice())

        if bs.first == 0:
            # New sequence: allocate a cache with the full time length of
            # bs.x and the trailing shape of f's output. It is zero-filled
            # rather than uninitialized, so steps not computed yet hold
            # deterministic values instead of arbitrary memory.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])

        self.cache_y[:, bs.first : bs.first + bs.nb] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb)
69
70
71 ##############################
72
73
class WithResidual(nn.Module):
    """Residual connection on BracketedSequences: return x + f(x).

    The wrapped module(s) take and return BracketedSequences. The sum is taken
    over the full tensors, and the input bracket is kept unchanged.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        branch = self.f(bs)
        summed = bs.x + branch.x
        return BracketedSequence(summed, bs.first, bs.nb)
81
82
83 ##############################
84
85
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a BracketedSequence.

    [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    with L = len_max and D the size of axis 2 of the input.

    The encoding table and the output cache are rebuilt whenever a bracket
    starts at t=0. Later brackets reuse them.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        if bs.first == 0:
            # New sequence: compute the (T, D) encoding table for the full
            # time length of bs.x, and reset the output cache.
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # cos(u) == sin(u + pi/2), so odd channels get a pi/2 phase shift
            # and both cases share a single sin() call.
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled rather than uninitialized, so steps not computed
            # yet hold deterministic values.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb)
112
113
114 ##############################
115
116
class QKVAttention(nn.Module):
    """Multi-head, optionally causal, self-attention over a BracketedSequence.

    Keys, values and outputs are cached per time step. A causal model can
    therefore be evaluated bracket by bracket: the queries of the current
    bracket attend to the cached keys/values of every step up to the end of
    that bracket.

    Args:
        dim_in: size of the input and output channel axis.
        dim_qk: per-head dimension of queries and keys.
        dim_v: per-head dimension of values.
        nb_heads: number of attention heads.
        causal: if True, step t only attends to steps s <= t.
        attention_dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*shape):
            # Gaussian init scaled by 1/sqrt(last dimension).
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        # When True, forward() stores the post-softmax attention in self.a.
        self.record_attention = False

        # Creation order fixes which random draws each weight receives.
        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs_q):
        x_q = bs_q.x
        start = bs_q.first
        stop = bs_q.first + bs_q.nb

        assert (
            self.causal or bs_q.complete()
        ), "Partial evaluation is only possible for causal models"

        nb_samples = x_q.size(0)
        total_len = x_q.size(1)

        if start == 0:
            # New sequence: reset the per-step caches of keys, values and outputs.
            self.cache_k = x_q.new_zeros(
                nb_samples, self.w_k.size(0), total_len, self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                nb_samples, self.w_v.size(0), total_len, self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(nb_samples, total_len, self.w_o.size(1))

        x_now = x_q[:, start:stop]

        def project(w):
            # Per-head projection of the current bracket: (N, T, C) -> (N, H, T, D).
            return torch.einsum("ntc,hdc->nhtd", x_now, w)

        q = project(self.w_q)
        self.cache_k[:, :, start:stop] = project(self.w_k)
        self.cache_v[:, :, start:stop] = project(self.w_v)

        # Queries of this bracket against all keys up to its end.
        scores = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, :stop]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if start == 0:
                # True where key step s lies strictly after query step t.
                positions = torch.arange(total_len, device=q.device)
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            scores = scores.masked_fill(
                self.cache_attzero[:, :, start:stop, :stop], float("-inf")
            )

        weights = scores.softmax(dim=3)

        if self.record_attention:
            self.a = weights

        weights = F.dropout(weights, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", weights, self.cache_v[:, :, :stop]
        ).flatten(2)

        self.cache_y[:, start:stop] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
199
200
201 ##############################
202
203
class MyGPT(nn.Module):
    """GPT-style model operating on BracketedSequences.

    The model has three stages:
    - Embedding: a token embedding (followed by dropout) plus sinusoidal
      positional encoding.
    - Trunk: ``nb_blocks`` pairs of residual blocks. The first of each pair
      is layer-norm + multi-head attention; the second is layer-norm +
      two-layer ReLU MLP.
    - Readout: a linear layer producing ``vocabulary_size`` logits.

    Every stage caches its per-step outputs, so a causal model can be
    evaluated bracket by bracket (see masked_inplace_autoregression).
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        causal=False,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        # The per-head value dimension is dim_model // nb_heads.
        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                    ),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, bs):
        """Compute the logits for the bracket of ``bs``; returns a BracketedSequence."""
        # Shift the tokens one step right along the last axis: prepend a 0
        # and drop the last one. The output at step t then only depends on
        # input tokens before t.
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)
        return bs

    def masked_inplace_autoregression(
        self,
        input,
        ar_mask,
        deterministic_synthesis=False,
        forbidden_tokens=None,
        forced_biases=None,
    ):
        """Generate tokens in place, one time step at a time.

        Args:
            input: token tensor, modified in place; indexed as input[:, s].
            ar_mask: tensor of 0s and 1s with the same shape as input. A 1
                marks a token to generate; tokens marked 0 are kept unchanged.
            deterministic_synthesis: if True, take the argmax of the logits;
                otherwise sample from the categorical distribution they define.
            forbidden_tokens: optional boolean mask; masked logits are set
                to -inf.
            forced_biases: optional 1D tensor added to every row of the logits.
        """
        to_generate = (ar_mask.sum(0) > 0).nonzero()

        # Nothing to generate. Without this check, .min() on an empty
        # tensor raises.
        if to_generate.numel() == 0:
            return

        first_t = to_generate.min().item()
        last_t = to_generate.max().item()

        if first_t > 0:
            # Needed to initialize the model's cache with the fixed prefix.
            self(BracketedSequence(input, 0, first_t))

        # Every step in [first_t, last_t] goes through the model to keep the
        # caches contiguous, even where no row generates a token.
        for s in range(first_t, last_t + 1):
            output = self(BracketedSequence(input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if forced_biases is not None:
                logits = logits + forced_biases[None, :]
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

    def record_attention(self, v=True):
        """Turn attention recording on or off in every QKVAttention sub-module."""
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        """Return the last recorded attention tensor of each QKVAttention.

        record_attention(True) must have been called before a forward pass,
        otherwise the ``a`` attribute does not exist.
        """
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a
316
317
318 ######################################################################
319
if __name__ == "__main__":
    print("Basic check.")

    # Consistency check: a full-sequence pass must match an incremental,
    # one-step-at-a-time pass that relies on the caches.
    vocabulary_size = 3
    x = torch.randint(vocabulary_size, (1, 5))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=4,
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
        nb_blocks=2,
        dropout=0.1,
        causal=True,
    )

    model.eval()

    # Reference: one pass over the whole sequence.
    y1 = model(BracketedSequence(x)).x

    # Incremental: one bracket of length 1 per time step.
    y2 = torch.randn_like(y1)
    for step in range(x.size(1)):
        y2[:, step] = model(BracketedSequence(x, step, 1)).slice()

    # Relative discrepancy; should be close to 0.
    error = (y1 - y2).norm() / (y1.norm() + y2.norm())
    print(f"error={error.item()}")
343
344     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
345
346 ######################################################################