Update.
[picoclvr.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 import math
14
15 import torch
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 ######################################################################
21
22 # A BracketedSequence is a BxTx... tensor together with a first time
23 # step and a number nb of time steps to compute.
24
25 # Modules able to process it expect that they will have to process a
26 # first bracket starting at t=0, followed by a succession of brackets
27 # that move forward in time, do not overlap, and cover the axis T with
28 # no holes.
29 #
30 # Although it is more general, for a classical prompt-conditioned
31 # auto-regressive process it will be a first bracket starting at 0 and
32 # of arbitrary length for the "prompt", followed by brackets of length
33 # 1 for the successive tokens.
34 #
35 # Modules able to process brackets may implement a cache that is
36 # reset when the input bracket starts at t=0.
37
38
class BracketedSequence:
    """A BxTx... tensor paired with a window ("bracket") of time steps.

    The bracket starts at index ``first`` along dimension 1 and covers
    ``nb`` steps. When omitted, ``first`` defaults to 0 and ``nb`` to the
    full length ``x.size(1)``, i.e. the whole sequence.
    """

    def __init__(self, x, first=None, nb=None):
        self.x = x
        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        self.first = first
        self.nb = nb

    def slice(self):
        """Return the part of ``x`` covered by the bracket (along dim 1)."""
        start, stop = self.first, self.first + self.nb
        return self.x[:, start:stop]
47
48
49 ######################################################################
50
51
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket and accumulate outputs in a cache.

    The wrapped computation only sees ``bs.slice()``. Its result is written
    into ``self.cache_y`` at the bracket's time positions, and the returned
    BracketedSequence wraps the whole cache with the same bracket. A
    bracket starting at t=0 (re)allocates the cache; any other bracket
    writes into the cache left by previous calls.

    Args:
        *f: a single module, used as-is, or several modules that are
            chained with ``nn.Sequential``.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        lo, hi = bs.first, bs.first + bs.nb
        out = self.f(bs.slice())
        if bs.first == 0:
            # New sequence: allocate an uninitialized buffer with the batch
            # size and trailing shape of the output, and the full time length
            # of the input. Only written positions are ever read back.
            full_shape = (out.size(0), bs.x.size(1)) + tuple(out.size()[2:])
            self.cache_y = out.new_empty(full_shape)
        self.cache_y[:, lo:hi] = out
        return BracketedSequence(self.cache_y, bs.first, bs.nb)
66
67
68 ##############################
69
70
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module.

    Returns a BracketedSequence holding ``bs.x + f(bs).x`` (the sum is taken
    over the whole tensor, not just the bracket), with the input's bracket.

    Args:
        *f: a single module, used as-is, or several modules that are
            chained with ``nn.Sequential``.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        branch = self.f(bs)
        summed = bs.x + branch.x
        return BracketedSequence(summed, bs.first, bs.nb)
78
79
80 ##############################
81
82
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a bracketed sequence.

    Follows [Vaswani et al. 2017]:

        PE_{t,2i}   = sin(t / L^{2i/D})
        PE_{t,2i+1} = cos(t / L^{2i/D})

    where L is ``len_max`` and D is the size of dimension 2 of the input.
    The encoding table ``self.pe`` and the output cache ``self.cache_y`` are
    rebuilt whenever a bracket starting at t=0 is processed; later brackets
    only fill in their own time steps.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x
        if bs.first == 0:
            opts = dict(dtype=x.dtype, device=x.device)
            pos = torch.arange(x.size(1), **opts).unsqueeze(1)
            dim = torch.arange(x.size(2), **opts).unsqueeze(0)
            parity = dim % 2
            # cos(u) == sin(u + pi/2): odd channels get a pi/2 phase shift,
            # and share the frequency of the preceding even channel.
            self.pe = torch.sin(
                pos / (self.len_max ** ((dim - parity) / x.size(2)))
                + math.pi / 2 * parity
            )
            self.cache_y = x.new_empty(x.size())

        lo, hi = bs.first, bs.first + bs.nb
        self.cache_y[:, lo:hi] = bs.slice() + self.pe[lo:hi]
        return BracketedSequence(self.cache_y, bs.first, bs.nb)
109
110
111 ##############################
112
113
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a key/value cache.

    Projected keys and values of every processed bracket are stored in
    ``self.cache_k`` / ``self.cache_v`` (heads first, then time), so a later
    bracket only computes projections for its own time steps and attends
    over all cached positions up to its end. Outputs are accumulated in
    ``self.cache_y``.

    Args:
        dim_in: size of the input (and output) feature dimension.
        dim_qk: per-head size of queries and keys.
        dim_v: per-head size of values.
        nb_heads: number of attention heads.
        causal: must be True; the constructor asserts it because the cache
            is not yet switched off for the non-causal case.
        attention_dropout: dropout probability on the attention weights.
    """

    def __init__(
        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
    ):
        super().__init__()

        def randw(*d):
            # Gaussian init scaled by 1/sqrt of the last dimension.
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        assert causal, "TODO: Switch off the cache when non-causal!!!"
        self.causal = causal
        self.attention_dropout = attention_dropout

        # Created in this order so the random initialization is reproducible.
        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs_q):
        x = bs_q.x
        batch, length = x.size(0), x.size(1)
        lo, hi = bs_q.first, bs_q.first + bs_q.nb

        if bs_q.first == 0:
            # New sequence: zeroed caches covering every time position.
            self.cache_k = x.new_zeros(
                batch, self.w_k.size(0), length, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                batch, self.w_v.size(0), length, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(batch, length, self.w_o.size(1))

        x_cur = x[:, lo:hi]
        q = torch.einsum("ntc,hdc->nhtd", x_cur, self.w_q)
        self.cache_k[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_cur, self.w_k)
        self.cache_v[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_cur, self.w_v)

        # Keys/values of all positions from 0 to the end of this bracket.
        keys = self.cache_k[:, :, :hi]
        values = self.cache_v[:, :, :hi]

        scores = torch.einsum("nhtd,nhsd->nhts", q, keys) / math.sqrt(
            self.w_q.size(1)
        )

        if self.causal:
            if bs_q.first == 0:
                # mask[t, s] is True when key position s is after query t.
                steps = torch.arange(length, device=q.device)
                self.cache_attzero = (
                    steps[None, None, :, None] < steps[None, None, None, :]
                )
            scores = scores.masked_fill(
                self.cache_attzero[:, :, lo:hi, :hi], float("-inf")
            )

        weights = F.dropout(
            scores.softmax(dim=3), self.attention_dropout, self.training
        )

        # Concatenate the heads, then project back to dim_in.
        y = torch.einsum("nhts,nhsd->nthd", weights, values).flatten(2)
        self.cache_y[:, lo:hi] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
182
183
184 ##############################
185
186
class MyGPT(nn.Module):
    """GPT-style model made of cache-aware bracket-processing modules.

    Pipeline: token embedding (+ dropout) and sinusoidal positional
    encoding, then ``nb_blocks`` pairs of residual blocks (LayerNorm +
    multi-head attention, LayerNorm + two-layer ReLU MLP), then a linear
    readout to vocabulary logits. Every stage takes and returns a
    BracketedSequence and caches its outputs. The model can therefore be
    run on a whole sequence at once or incrementally, one bracket at a time.

    Args:
        vocabulary_size: number of distinct token values.
        dim_model: embedding size; must be divisible by ``nb_heads``.
        dim_keys: per-head size of queries and keys.
        dim_hidden: hidden size of the MLP blocks.
        nb_heads: number of attention heads.
        nb_blocks: number of (attention, MLP) block pairs.
        causal: forwarded to QKVAttention, which currently asserts it is True.
        dropout: dropout probability for the embedding, the attention
            weights and the MLP output.
        len_max: the ``L`` constant of the sinusoidal positional encoding.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        causal=False,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        # Each head gets dim_model // nb_heads value channels.
        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                # Attention sub-block: LayerNorm then multi-head attention.
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                    ),
                ),
                # MLP sub-block: LayerNorm then a two-layer ReLU network.
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, bs):
        """Compute vocabulary logits for the bracket ``bs``.

        The tokens are shifted right by one position along the last
        dimension: a 0 is inserted at the front and the last token is
        dropped. Combined with the causal attention, this makes the output
        at position t a prediction of token t from the tokens before it.
        """
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)
        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self, input, ar_mask, forbidden_tokens=None, deterministic_synthesis=False
    ):
        """Generate tokens in place wherever ``ar_mask`` is 1.

        Args:
            input: token tensor, modified in place; presumably (batch, time).
            ar_mask: tensor of 0s and 1s with the same shape as ``input``.
                1s mark positions to generate; other positions are kept.
            forbidden_tokens: optional boolean mask, broadcastable to the
                (batch, vocabulary) logits. Masked tokens get -inf logits and
                are never produced.
            deterministic_synthesis: if True, take the argmax token instead
                of sampling from the predicted categorical distribution.

        Time steps are processed left to right, from the first to the last
        step where any row of ``ar_mask`` has a 1. If ``ar_mask`` contains
        no 1 at all, ``input`` is left untouched.
        """
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.numel() == 0:
            # Nothing to generate; min()/max() would raise on an empty tensor.
            return
        first = to_generate.min().item()
        last = to_generate.max().item()
        if first > 0:
            # Run the prefix once to initialize the model's cache.
            self(BracketedSequence(input, 0, first))
        for s in range(first, last + 1):
            output = self(BracketedSequence(input, s, 1)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Only rows flagged in ar_mask at step s take the new token.
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
279
280
281 ######################################################################
282
if __name__ == "__main__":
    # Sanity check: running the model on the whole sequence in one bracket
    # should match running it one time step at a time through the caches.
    print("Basic check.")

    vocabulary_size = 3
    x = torch.randint(vocabulary_size, (1, 5))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=4,
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
        nb_blocks=1,
        dropout=0.1,
        causal=True,
    )
    model.eval()  # disables dropout so both passes are comparable

    full_output = model(BracketedSequence(x)).x

    step_output = torch.randn_like(full_output)
    for t in range(x.size(1)):
        step_output[:, t] = model(BracketedSequence(x, t, 1)).slice()

    diff = (full_output - step_output).norm()
    scale = full_output.norm() + step_output.norm()
    print(f"error={(diff / scale).item()}")
308     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
309
310 ######################################################################