Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers to replace the MHA
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor with a first and a nb time
29 # steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # resetted when init_cache is True
43
44
45 class BracketedSequence:
46     def __init__(self, x, first=None, nb=None, init_cache=None):
47         self.x = x
48         assert (first is None and nb is None and init_cache is None) or (
49             first is not None and nb is not None and init_cache is not None
50         )
51
52         self.first = 0 if first is None else first
53         self.nb = x.size(1) if nb is None else nb
54         self.init_cache = True if init_cache is None else init_cache
55
56     def slice(self):
57         return self.x[:, self.first : self.first + self.nb]
58
59     def complete(self):
60         return self.first == 0 and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
66 class CacheWrapper(nn.Module):
67     def __init__(self, *f):
68         super().__init__()
69         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
70
71     def forward(self, bs):
72         if bs.init_cache:
73             y = self.f(bs.slice())
74             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
75             self.cache_y[:, bs.first : bs.first + bs.nb] = y
76         else:
77             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
78             assert bs.first + bs.nb <= self.cache_y.size(1)
79             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
80
81         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
87 class WithResidual(nn.Module):
88     def __init__(self, *f):
89         super().__init__()
90         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
91
92     def forward(self, bs):
93         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
99 class AddPositionalEncoding(nn.Module):
100     def __init__(self, len_max):
101         super().__init__()
102         self.len_max = len_max
103
104     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
105
106     def forward(self, bs):
107         if bs.init_cache:
108             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109                 :, None
110             ]
111             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
112                 None, :
113             ]
114             k = j % 2
115             self.pe = torch.sin(
116                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
117             )
118             self.cache_y = bs.x.new(bs.x.size())
119
120         self.cache_y[:, bs.first : bs.first + bs.nb] = (
121             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122         )
123
124         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
133 def pscan_dim(A, X, Y_init, dim=-2):
134     s = X.size()
135     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
136
137     A = A.reshape(a, T, *s[dim + 1 : -1])
138     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139
140     if Y_init is None:
141         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
142     else:
143         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
144
145     Y = pscan.pscan(A, X, Y_init).reshape(s)
146
147     return Y
148
149
150 def pscan_shape(A, X, Y_init):
151     s = X.size()
152     A = A.reshape(-1, s[-2])
153     X = X.reshape(-1, s[-2], s[-1])
154
155     if Y_init is None:
156         Y_init = X.new_zeros(X.size(0), s[-1])
157     else:
158         Y_init = Y_init.reshape(-1, s[-1])
159
160     Y = pscan.pscan(A, X, Y_init).reshape(s)
161
162     return Y
163
164
165 def nsum_shape(X, Y_init):
166     s = X.size()
167     X = X.reshape(-1, s[-2], s[-1])  # ntd
168
169     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170     result = []
171
172     for k in range(X.size(1)):
173         Y = Y + X[:, k]
174         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175         result.append(Y)
176
177     return torch.cat(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
183 class DumbRec(nn.Module):
184     def __init__(
185         self,
186         dim_model,
187         dim_qk,
188         dim_v,
189         nb_heads,
190         nb_lines,
191         attention_dropout=0.0,
192         len_max=1e5,
193     ):
194         super().__init__()
195
196         def randw(*d):
197             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
198
199         self.nb_lines = nb_lines
200         self.attention_dropout = attention_dropout
201
202         self.k_star = randw(nb_lines, dim_qk)
203
204         self.w_qw = randw(nb_heads, dim_qk, dim_model)
205         self.w_qr = randw(nb_heads, dim_qk, dim_model)
206         # self.w_k = randw(nb_heads, dim_qk, dim_model)
207         self.w_v = randw(nb_heads, dim_v, dim_model)
208         self.w_o = randw(dim_v * nb_heads, dim_model)
209
210     def reset_inner_loss(self):
211         self.acc_attention = 0
212         self.acc_nb = 0
213
214     def get_inner_loss(self):
215         warnings.warn("l2 regularization", RuntimeWarning)
216         return (self.acc_attention / self.acc_nb).pow(2).sum()
217         # return torch.tensor([0], device=self.w_qw.device)
218
219     def forward(self, bs):
220         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
221
222         if bs.init_cache:
223             self.rec_v = x_q.new_zeros(
224                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
225             )
226             # self.rec_k = x_q.new_zeros(
227             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
228             # )
229             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
230
231         ######################################################################
232         # Prepare the keys
233
234         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
235
236         warnings.warn("rotating key barrel", RuntimeWarning)
237         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
238         t_barrel = torch.arange(t0, t1, device=k_star.device)
239         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
240         l_barrel = (
241             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
242         ) % k_star.size(0)
243         k_star = k_star[l_barrel, t_barrel]
244
245         ######################################################################
246         # Compute the recurrent state
247
248         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
249
250         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
251         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
252
253         aw = torch.einsum(
254             "nhtd,ltd->nhlt",
255             qw,
256             k_star,
257         ) / math.sqrt(self.w_qw.size(1))
258
259         aw = aw.softmax(dim=2)  # nhlt
260
261         if self.train:
262             self.acc_attention += aw.sum(dim=(0, 1, 3))
263             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
264
265         aw = F.dropout(aw, self.attention_dropout, self.training)
266
267         A = 1 - aw.sum(dim=1)  # nlt
268
269         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
270         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
271
272         if t0 == 0:
273             V0 = None
274             # K0 = None
275         else:
276             V0 = self.rec_v[:, :, t0 - 1]
277             # K0 = self.rec_k[:, :, t0 - 1]
278
279         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
280         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
281
282         ######################################################################
283         # compute the readout
284
285         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
286
287         ar = torch.einsum(
288             "nhtd,ld->nhlt",
289             qr,
290             # self.rec_k[:, :, t0:t1],
291             self.k_star,
292         ) / math.sqrt(self.w_qr.size(1))
293
294         ar = ar.softmax(dim=2)  # nhlt
295
296         ar = F.dropout(ar, self.attention_dropout, self.training)
297
298         y = torch.einsum(
299             "nhlt,nltd->nthd",
300             ar,
301             self.rec_v[:, :, t0:t1],
302         ).flatten(2)
303
304         self.cache_y[:, t0:t1] = y @ self.w_o
305
306         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
312 class KVRec(nn.Module):
313     def __init__(
314         self,
315         dim_model,
316         dim_qk,
317         dim_v,
318         nb_heads,
319         nb_lines,
320         attention_dropout=0.0,
321         len_max=1e5,
322     ):
323         super().__init__()
324
325         def randw(*d):
326             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
327
328         self.nb_lines = nb_lines
329         self.attention_dropout = attention_dropout
330
331         self.k_star = randw(nb_lines, dim_qk)
332
333         self.w_qw = randw(nb_heads, dim_qk, dim_model)
334         self.w_qr = randw(nb_heads, dim_qk, dim_model)
335         self.w_k = randw(nb_heads, dim_qk, dim_model)
336         self.w_v = randw(nb_heads, dim_v, dim_model)
337         self.w_o = randw(dim_v * nb_heads, dim_model)
338
339     def reset_inner_loss(self):
340         self.acc_attention = 0
341         self.acc_nb = 0
342
343     def get_inner_loss(self):
344         warnings.warn("l2 regularization", RuntimeWarning)
345         return (self.acc_attention / self.acc_nb).pow(2).sum()
346         # return torch.tensor([0], device=self.w_qw.device)
347         # warnings.warn("side regularization", RuntimeWarning)
348         # return (
349         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
350         # )
351         # return torch.tensor([0], device=self.w_qw.device)
352
353     def forward(self, bs):
354         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
355
356         if bs.init_cache:
357             self.rec_v = x_q.new_zeros(
358                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
359             )
360             self.rec_k = x_q.new_zeros(
361                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
362             )
363             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
364
365         ######################################################################
366         # Prepare the keys
367
368         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
369
370         warnings.warn("rotating key barrel", RuntimeWarning)
371         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
372         t_barrel = torch.arange(t0, t1, device=k_star.device)
373         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
374         l_barrel = (
375             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
376         ) % k_star.size(0)
377         k_star = k_star[l_barrel, t_barrel]
378
379         ######################################################################
380         # Compute the recurrent state
381
382         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
383
384         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
385         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
386
387         aw = torch.einsum(
388             "nhtd,ltd->nhlt",
389             qw,
390             k_star,
391         ) / math.sqrt(self.w_qw.size(1))
392
393         aw = aw.softmax(dim=2)  # nhlt
394
395         if self.train:
396             # We want all the memory lines to be used similarly
397             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum accross NxHx_xT
398             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
399
400         aw = F.dropout(aw, self.attention_dropout, self.training)
401
402         A = 1 - aw.sum(dim=1)  # nlt
403
404         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
405         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
406
407         if t0 == 0:
408             V0 = None
409             K0 = None
410         else:
411             V0 = self.rec_v[:, :, t0 - 1]
412             K0 = self.rec_k[:, :, t0 - 1]
413
414         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
415         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
416
417         ######################################################################
418         # compute the readout
419
420         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
421
422         ar = torch.einsum(
423             "nhtd,nltd->nhlt",
424             qr,
425             self.rec_k[:, :, t0:t1],
426         ) / math.sqrt(self.w_qr.size(1))
427
428         ar = ar.softmax(dim=2)  # nhlt
429
430         ar = F.dropout(ar, self.attention_dropout, self.training)
431
432         y = torch.einsum(
433             "nhlt,nltd->nthd",
434             ar,
435             self.rec_v[:, :, t0:t1],
436         ).flatten(2)
437
438         self.cache_y[:, t0:t1] = y @ self.w_o
439
440         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
446 # Returns a tensor with an additional index at rank win_dim, that move
447 # along the same dimension as dim, on a domain {0...win_size-1}, and
448 # dim is restricted on a domain reduced by win_size-1 values.
449
450
451 def moving_window(x, dim, win_dim, win_size):
452     size, stride = x.size(), x.stride()
453     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
454     size = size[:win_dim] + (win_size,) + size[win_dim:]
455     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
456
457     return x.as_strided(size=size, stride=stride)
458
459
460 ##############################
461
462
463 class Caterpillar(nn.Module):
464     def __init__(
465         self,
466         dim_model,
467         dim_qk,
468         dim_v,
469         nb_heads,
470         caterpillar_length,
471         caterpillar_height,
472         attention_dropout=0.0,
473         len_max=1e5,
474     ):
475         super().__init__()
476
477         warnings.warn("Caterpillar", RuntimeWarning)
478
479         def randw(*d):
480             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
481
482         self.caterpillar_length = caterpillar_length
483         self.caterpillar_height = caterpillar_height
484         self.attention_dropout = attention_dropout
485
486         self.proba_gate_dropout = 0.0
487
488         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
489         self.b_G = nn.Parameter(
490             torch.full(
491                 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
492             )
493         )
494
495         self.w_K = randw(nb_heads, dim_qk, dim_model)
496         self.w_V = randw(nb_heads, dim_v, dim_model)
497         self.w_Q = randw(nb_heads, dim_qk, dim_model)
498         self.w_O = randw(dim_v * nb_heads, dim_model)
499
500         self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
501         self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
502
503     def reset_inner_loss(self):
504         self.acc_attention = 0
505         self.acc_nb = 0
506
507     def get_inner_loss(self):
508         # warnings.warn("l2 regularization", RuntimeWarning)
509         # return (self.acc_attention / self.acc_nb).pow(2).sum()
510         return torch.tensor([0], device=self.w_Q.device)
511
512     def forward(self, bs):
513         # Dimensions to make the source a bit clearer, that's needed
514
515         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
516
517         N = bs.x.size(0)
518         T = bs.x.size(1)
519         H = self.w_V.size(0)
520         DV = self.w_V.size(1)
521         DK = self.w_K.size(1)
522         DM = self.w_O.size(1)
523         CH = self.caterpillar_height
524         CL = self.caterpillar_length
525
526         assert (
527             t0 >= CL and (t1 - t0) % CL == 0
528         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
529
530         # We cache values to deal efficiently with auto-regression
531
532         if bs.init_cache:
533             self.rec_V = X.new_zeros(N, CH, T, DV)
534             self.rec_K = X.new_zeros(N, CH, T, DK)
535             # We start the recurrent sequences with optimizable
536             # initial values. No idea if it helps.
537             self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
538             self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
539
540             self.cache_Y = X.new_zeros(N, T, DM)
541
542         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
543         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
544
545         ######################################################################
546         # Compute the recurrent state
547
548         # This is the Gating sequence that modulates the storing of
549         # the new key and value in the CH pairs of the current
550         # stack. There are CH independent gating values, which means
551         # that the current K/V may be stored in multiple pairs of the
552         # recurrent state, or not at all.
553
554         G = (
555             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
556         ).sigmoid()
557
558         # Clip the gating to avoid values greater than 1 when several
559         # heads hit the same row
560
561         G = G / G.sum(1, keepdim=True).clamp(min=1)
562
563         # We prepare the arguments for the parallel scan
564
565         A = 1 - G.sum(1)
566         gated_V = torch.einsum("nhet,nhtd->netd", G, V)
567         gated_K = torch.einsum("nhet,nhtd->netd", G, K)
568
569         # We start from cached values, which matters in inference
570
571         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
572         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
573
574         ######################################################################
575
576         if self.training and self.proba_gate_dropout > 0.0:
577             # This is a better implementation of "flashbacks".  A is
578             # NxExT where e is the caterpillar's row.
579
580             warnings.warn("gate dropout", RuntimeWarning)
581             epsilon = 0.5
582
583         #################################################################
584         # Associative scan
585
586         # Here there is a trick: Since the stack at position t is
587         # computed by updating that at position t-CL, the parallel
588         # scan operates with a period of CL. To do so we split the
589         # sequence indexing in two axes, the second of size CL, and
590         # run the parallel scan using the first as the sequence index.
591
592         A = A.unflatten(2, (-1, CL))
593         gated_V = gated_V.unflatten(2, (-1, CL))
594         gated_K = gated_K.unflatten(2, (-1, CL))
595
596         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
597         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
598
599         self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
600         self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
601
602         ######################################################################
603         # compute the readout
604
605         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
606
607         # We build tensors NxHxTxFxL where N is the sample index, H
608         # the head, T the time, F the row in the caterpillar, and L
609         # the column in the caterpillar
610
611         windowed_V = moving_window(
612             self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
613         )
614
615         windowed_K = moving_window(
616             self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
617         )
618
619         # We have an attention score for each of the CHxCL values
620
621         ar = torch.einsum(
622             "nhtd,nftld->nhtfl",
623             Q,
624             windowed_K,
625         ) / math.sqrt(DK)
626
627         # softmax can operate only on one dimension, hence the
628         # flattening
629
630         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
631
632         ar = F.dropout(ar, self.attention_dropout, self.training)
633
634         # Compute the output for each head, flatten to concatenate
635
636         Y = torch.einsum(
637             "nhtfl,nftld->nthd",
638             ar,
639             windowed_V,
640         ).flatten(2)
641
642         # Compute the final output
643
644         self.cache_Y[:, t0:t1] = Y @ self.w_O
645
646         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
647
648
649 ##############################
650
651
652 class QKVAttention(nn.Module):
653     def __init__(
654         self,
655         dim_model,
656         dim_qk,
657         dim_v,
658         nb_heads=1,
659         causal=False,
660         attention_dropout=0.0,
661     ):
662         super().__init__()
663
664         def randw(*d):
665             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
666
667         self.causal = causal
668         self.attention_dropout = attention_dropout
669         self.record_attention = False
670
671         self.w_q = randw(nb_heads, dim_qk, dim_model)
672         self.w_k = randw(nb_heads, dim_qk, dim_model)
673         self.w_v = randw(nb_heads, dim_v, dim_model)
674         self.w_o = randw(dim_v * nb_heads, dim_model)
675
676     def forward(self, bs):
677         x_q = bs.x
678
679         assert (
680             self.causal or bs.complete()
681         ), "Partial evaluation is only possible for causal models"
682
683         if bs.init_cache:
684             self.cache_k = x_q.new_zeros(
685                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
686             )
687             self.cache_v = x_q.new_zeros(
688                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
689             )
690             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
691
692         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
693
694         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
695             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
696         )
697         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
698             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
699         )
700
701         a = torch.einsum(
702             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
703         ) / math.sqrt(self.w_q.size(1))
704
705         if self.causal:
706             if bs.init_cache:
707                 self.cache_attzero = (
708                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
709                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
710                 )
711             a = a.masked_fill(
712                 self.cache_attzero[
713                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
714                 ],
715                 float("-inf"),
716             )
717
718         a = a.softmax(dim=3)
719
720         if self.record_attention:
721             self.a = a
722
723         a = F.dropout(a, self.attention_dropout, self.training)
724
725         y = torch.einsum(
726             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
727         ).flatten(2)
728
729         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
730
731         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
732
733
734 ##############################
735
736
737 class MyGPT(nn.Module):
738     def __init__(
739         self,
740         vocabulary_size,
741         dim_model,
742         dim_keys,
743         dim_hidden,
744         nb_heads,
745         nb_blocks,
746         nb_lines=None,
747         caterpillar_height=None,
748         causal=False,
749         dropout=0.0,
750         len_max=1e5,
751         attention_layer="kvrec",
752     ):
753         super().__init__()
754
755         assert attention_layer in {
756             "mha",
757             "dumbrec",
758             "kvrec",
759             "caterpillar",
760         }, f"Unknown attention operator {attention_layer}."
761
762         if attention_layer == "caterpillar":
763             assert nb_lines % caterpillar_height == 0
764             self.caterpillar_length = nb_lines // caterpillar_height
765             self.caterpillar_height = caterpillar_height
766         else:
767             self.caterpillar_length = -1
768             self.caterpillar_height = -1
769
770         assert dim_model % nb_heads == 0
771
772         self.embedding = nn.Sequential(
773             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
774             AddPositionalEncoding(len_max),
775         )
776
777         trunk_blocks = []
778
779         def attlayer():
780             if attention_layer == "mha":
781                 return QKVAttention(
782                     dim_model=dim_model,
783                     dim_qk=dim_keys,
784                     dim_v=dim_model // nb_heads,
785                     nb_heads=nb_heads,
786                     causal=causal,
787                     attention_dropout=dropout,
788                 )
789             elif attention_layer == "dumbrec":
790                 return DumbRec(
791                     dim_model=dim_model,
792                     dim_qk=dim_keys,
793                     dim_v=dim_model // nb_heads,
794                     nb_heads=nb_heads,
795                     nb_lines=nb_lines,
796                     attention_dropout=dropout,
797                 )
798             elif attention_layer == "kvrec":
799                 return KVRec(
800                     dim_model=dim_model,
801                     dim_qk=dim_keys,
802                     dim_v=dim_model // nb_heads,
803                     nb_heads=nb_heads,
804                     nb_lines=nb_lines,
805                     attention_dropout=dropout,
806                 )
807             elif attention_layer == "caterpillar":
808                 return Caterpillar(
809                     dim_model=dim_model,
810                     dim_qk=dim_keys,
811                     dim_v=dim_model // nb_heads,
812                     nb_heads=nb_heads,
813                     caterpillar_length=self.caterpillar_length,
814                     caterpillar_height=self.caterpillar_height,
815                     attention_dropout=dropout,
816                 )
817             else:
818                 raise ValueError(f"Unknown attention type {attention_layer}.")
819
820         for b in range(nb_blocks):
821             trunk_blocks += [
822                 WithResidual(
823                     CacheWrapper(nn.LayerNorm((dim_model,))),
824                     attlayer(),
825                 ),
826                 WithResidual(
827                     CacheWrapper(
828                         nn.LayerNorm((dim_model,)),
829                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
830                         nn.ReLU(),
831                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
832                         nn.Dropout(dropout),
833                     ),
834                 ),
835             ]
836
837         self.trunk = nn.Sequential(*trunk_blocks)
838
839         self.readout = CacheWrapper(
840             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
841         )
842
843         with torch.no_grad():
844             for m in self.modules():
845                 if isinstance(m, nn.Embedding):
846                     m.weight.normal_(mean=0, std=2e-2)
847                 elif isinstance(m, nn.LayerNorm):
848                     m.bias.zero_()
849                     m.weight.fill_(1.0)
850
851         self.reset_inner_loss()
852
853     def forward(self, bs):
854         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
855
856         # To make the code simpler in the Caterpillar layer, we pad
857         # here. It's unclear if/how much it hurts computationaly by
858         # increasing the sequence length for the other layers
859
860         if self.caterpillar_length > 0:
861             original_nb = bs.nb
862             if bs.nb % self.caterpillar_length > 0:
863                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
864
865             bs = BracketedSequence(
866                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
867                 bs.first + self.caterpillar_length,
868                 bs.nb,
869                 bs.init_cache,
870             )
871
872         bs = self.embedding(bs)
873         bs = self.trunk(bs)
874         bs = self.readout(bs)
875
876         if self.caterpillar_length > 0:
877             bs = BracketedSequence(
878                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
879                 bs.first - self.caterpillar_length,
880                 original_nb,
881                 bs.init_cache,
882             )
883
884         return bs
885
886     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
887     # 1s where tokens should be generated. The others are kept
888     # unchanged.
889
890     def masked_inplace_autoregression(
891         self,
892         input_src,
893         ar_mask_src,
894         forbidden_tokens=None,
895         deterministic_synthesis=False,
896     ):
897         input = input_src.to(self.readout.f.weight.device)
898         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
899         to_generate = (ar_mask.sum(0) > 0).nonzero()
900         if to_generate.min() > 0:
901             self(
902                 BracketedSequence(input, 0, to_generate.min(), True)
903             )  # Needed to initialize the model's cache
904         for s in range(to_generate.min(), to_generate.max() + 1):
905             output = self(BracketedSequence(input, s, 1, s == 0)).x
906             logits = output[:, s]
907             if forbidden_tokens is not None:
908                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
909             if deterministic_synthesis:
910                 t_next = logits.argmax(1)
911             else:
912                 dist = torch.distributions.categorical.Categorical(logits=logits)
913                 t_next = dist.sample()
914             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
915
916         input_src.copy_(input)
917
918     def reset_inner_loss(self):
919         for m in self.modules():
920             if m is not self and hasattr(m, "reset_inner_loss"):
921                 m.reset_inner_loss()
922
923     def get_inner_loss(self):
924         l = torch.tensor([0.0], device=self.readout.f.weight.device)
925         for m in self.modules():
926             if m is not self and hasattr(m, "get_inner_loss"):
927                 l += m.get_inner_loss()
928         return l
929
930     def record_attention(self, v=True):
931         for m in self.modules():
932             if isinstance(m, QKVAttention):
933                 m.record_attention = v
934
935     def retrieve_attention(self):
936         a = []
937         for m in self.modules():
938             if isinstance(m, QKVAttention):
939                 a.append(m.a)
940         return a
941
942
943 ######################################################################
944
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Caterpillar consistency check: the bracket [7, 28) is processed twice
    # in a single bracket (y1, y2), then as two consecutive brackets where
    # the second one reuses the cache (y3a, y3b). The printed max absolute
    # differences are presumably expected to be ~0.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The MyGPT timing benchmark below is currently disabled; remove this
    # exit to run it. sys.exit is used instead of the site-module exit(),
    # which is meant for the interactive interpreter and is undefined when
    # Python runs with -S.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    # Inference timing: no_grad avoids building an autograd graph per call
    with torch.no_grad():
        for _ in range(3):
            start_time = time.perf_counter()
            for _ in range(10):
                model(BracketedSequence(x))
            duration = time.perf_counter() - start_time
            print(duration)
            sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1016 ######################################################################