Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
# This implementation is equipped with RNN layers to replace the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with the index `first`
# of the first time step to compute and the number `nb` of time steps to
# compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor plus the bracket [first, first + nb) of time steps to compute.

    Either all three of ``first``, ``nb`` and ``init_cache`` are given, or
    none of them is. In the latter case the bracket covers the whole time
    axis and ``init_cache`` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return x restricted to the bracket along the time axis (dim 1)."""
        start = self.first
        return self.x[:, start : start + self.nb]

    def complete(self):
        """Return True when the bracket spans the full time axis."""
        covers_start = self.first == 0
        covers_all = self.nb == self.x.size(1)
        return covers_start and covers_all
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a time-step-wise module to the bracket and cache its output.

    The wrapped module(s) are applied only to the current bracket
    ``bs.slice()``. Their output is written into a cache spanning the whole
    time axis, which is (re)allocated when ``bs.init_cache`` is True. The
    returned BracketedSequence wraps that cache with the same bracket.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        start, stop = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # Later brackets must match the batch/time extent of the cache
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert stop <= self.cache_y.size(1)
            self.cache_y[:, start:stop] = self.f(bs.slice())
        else:
            out = self.f(bs.slice())
            # Same trailing dims as the output, full time length of the input;
            # entries outside the bracket are left uninitialized.
            full_shape = (out.size(0), bs.x.size(1)) + out.size()[2:]
            self.cache_y = out.new_empty(full_shape)
            self.cache_y[:, start:stop] = out

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module: x + f(x)."""

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs).x
        return BracketedSequence(skip + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add the sinusoidal positional encoding of Vaswani et al. (2017).

    PE_{t,2i} = sin(t / L^{2i/D}) and PE_{t,2i+1} = cos(t / L^{2i/D}), where
    L is ``len_max`` and D is the size of the last axis of the input. The
    encoding table and the output cache are built when ``bs.init_cache`` is
    True; only the current bracket of the cache is written at each call.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            steps = torch.arange(x.size(1), dtype=x.dtype, device=x.device).unsqueeze(1)
            channels = torch.arange(x.size(2), dtype=x.dtype, device=x.device).unsqueeze(0)
            odd = channels % 2
            # cos(a) == sin(a + pi/2): odd channels are shifted by pi/2 so a
            # single sin call produces both the sin and cos columns.
            self.pe = torch.sin(
                steps / (self.len_max ** ((channels - odd) / x.size(2)))
                + math.pi / 2 * odd
            )
            self.cache_y = x.new_empty(x.size())

        start, stop = bs.first, bs.first + bs.nb
        self.cache_y[:, start:stop] = bs.slice() + self.pe[start:stop]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` along axis ``dim`` of X.

    X has its sequence axis T at position ``dim`` and a feature axis last.
    A has the shape of X without its last axis. Y_init has the shape of X
    without axis ``dim``, or is None to start from zeros.

    The axes before ``dim`` are flattened into one batch axis. The axes
    strictly between ``dim`` and the last one are kept as-is. The result is
    reshaped back to the shape of X.
    """
    s = X.size()
    nb_leading, T = s[:dim].numel(), s[dim]
    # Axes between the scanned axis and the feature axis
    middle = s[dim + 1 : -1]

    A = A.reshape(nb_leading, T, *middle)
    X = X.reshape(nb_leading, T, *middle, -1)

    if Y_init is None:
        Y_init = X.new_zeros(nb_leading, *middle, X.size(-1))
    else:
        Y_init = Y_init.reshape(nb_leading, *middle, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
148
149
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` along the second-to-last axis of X.

    X is /.../xTxD, A is /.../xT and Y_init is /.../xD (or None for zeros).
    All leading axes are flattened into a single batch axis for the scan,
    and the result is reshaped back to the shape of X.
    """
    shape = X.size()
    length, width = shape[-2], shape[-1]

    flat_A = A.reshape(-1, length)
    flat_X = X.reshape(-1, length, width)
    if Y_init is not None:
        flat_init = Y_init.reshape(-1, width)
    else:
        flat_init = flat_X.new_zeros(flat_X.size(0), width)

    return pscan.pscan(flat_A, flat_X, flat_init).reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Norm-clipped running sum along the second-to-last axis of X.

    X is /.../xTxD and Y_init is /.../xD, or None to start from 0. For each
    step: Y_t = c(Y_{t-1} + X_t), where c divides a vector by its norm when
    that norm exceeds 1, so every Y_t has norm at most 1.

    Returns a tensor with the shape of X holding Y_0 ... Y_{T-1}.
    """
    s = X.size()

    # Nothing to accumulate; also reshape(-1, T, D) below cannot infer -1
    # for a zero-element tensor, and stacking an empty list fails.
    if X.numel() == 0:
        return X.new_zeros(s)

    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for x_t in X.unbind(dim=1):
        Y = Y + x_t
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.stack(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention with learned fixed keys.

    The recurrent state is ``nb_lines`` value vectors per sample and per
    time step (``rec_v``, N x nb_lines x T x dim_v).

    Writing: each head computes a softmax over the lines of its write query
    against ``k_star``, rotated with the time step (see below). It adds its
    value to each line with those weights. ``A = 1 - sum_heads(weights)`` is
    passed to ``pscan_shape`` as the coefficient of the previous state
    (presumably the recurrence Y_t = A_t * Y_{t-1} + X_t; see the pscan
    module).

    Reading: each head computes a second softmax over the lines, using
    ``w_qr`` against the un-rotated ``k_star``, and averages the line values.

    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Clear the write-attention statistics accumulated during training."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squares of the mean write attention per line.

        When nothing has been accumulated since the last reset (e.g. only
        eval-mode forwards), return a scalar zero instead of dividing by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if not self.acc_nb:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # "Rotating barrel": at time t, line l is written with the key
        # k_star[(l + t) % nb_lines]. After indexing, k_star is
        # nb_lines x (t1 - t0) x dim_qk.
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # `self.training` is the mode flag; `self.train` is the method that
        # sets it (always truthy), which would accumulate in eval mode too.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # NOTE(review): each head's weights sum to 1 over the lines, so when
        # several heads write the same line A can become negative.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Resume from the state at t0 - 1 when continuing a sequence
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention with recurrent keys and values.

    Like a fixed-key memory, each of the ``nb_lines`` lines is written
    through a softmax of a write query against ``k_star``, rotated with the
    time step. Each line also accumulates a key (``rec_k``) next to its value
    (``rec_v``), and the readout attends to those recurrent keys.

    ``A = 1 - sum_heads(weights)`` is passed to ``pscan_shape`` as the
    coefficient of the previous state (presumably the recurrence
    Y_t = A_t * Y_{t-1} + X_t; see the pscan module).

    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Clear the write-attention statistics accumulated during training."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squares of the mean write attention per line.

        When nothing has been accumulated since the last reset (e.g. only
        eval-mode forwards), return a scalar zero instead of dividing by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if not self.acc_nb:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # "Rotating barrel": at time t, line l is written with the key
        # k_star[(l + t) % nb_lines]. After indexing, k_star is
        # nb_lines x (t1 - t0) x dim_qk.
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # `self.training` is the mode flag; `self.train` is the method that
        # sets it (always truthy), which would accumulate in eval mode too.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # NOTE(review): each head's weights sum to 1 over the lines, so when
        # several heads write the same line A can become negative.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Resume from the state at t0 - 1 when continuing a sequence
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
# Returns a tensor with an additional index at rank win_dim that moves
# along the same dimension as dim, over the domain {0...win_size-1},
# while dim is restricted to a domain reduced by win_size-1 values.
449
450
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with a sliding window of size win_size along dim.

    Axis ``dim`` shrinks to size[dim] - win_size + 1, and a new axis of size
    ``win_size`` is inserted at position ``win_dim`` (Python tuple-insert
    semantics). Element [..., i, ..., k, ...] of the result is
    x[..., i + k, ...]. No data is copied: the result shares x's storage.

    A negative ``dim`` counts from the end of x's shape.
    """
    # Without this, dim=-1 would make size[dim + 1 :] below equal to the
    # whole shape (size[0:]) instead of an empty tuple.
    if dim < 0:
        dim += x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    # Moving along the window axis steps through dim in the storage
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
458
459
460 ##############################
461
462
class Caterpillar(nn.Module):
    """Recurrent attention over a "caterpillar" of key/value pairs.

    The recurrent state at each time step is R = caterpillar_height rows of
    key/value pairs. The state at time t is computed from the state at time
    t - L (L = caterpillar_length).

    Per head, a sigmoid gate decides how much of the new key/value goes into
    each row. The gates are renormalized so that their sum over heads is at
    most 1. ``A = 1 - sum_heads(gates)`` is passed to ``pscan_dim`` as the
    coefficient of the previous state (presumably Y_t = A_t * Y_{t-L} + X_t;
    see the pscan module). The row that a gate applies to rotates by one
    every L steps.

    The readout attends, with one softmax per head, over the R x L pairs of
    the L most recent states.

    forward() requires bs.first >= L and bs.nb to be a multiple of L.
    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        # The gating bias below is -log(R - 1), i.e. an initial gate of
        # sigmoid(-log(R - 1)) = 1 / R, which is only defined for R >= 2.
        # Report that explicitly instead of a bare "math domain error".
        if caterpillar_height < 2:
            raise ValueError(
                f"caterpillar_height must be at least 2, got {caterpillar_height}"
            )

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Probability of the "flashback" gate dropout; 0.0 disables it
        self.proba_gate_dropout = 0.0

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned initial content of the L states preceding bs.first
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """This layer has no auxiliary loss; always returns tensor([0])."""
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        # Named dimensions, to make the tensor shapes below easier to follow
        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert t0 >= L and (t1 - t0) % L == 0, (
            "bs.first should be greater than or equal to caterpillar_length, "
            "and bs.nb should be a multiple of caterpillar_length "
            f"(got bs.first={t0}, bs.nb={t1 - t0}, caterpillar_length={L})"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()  # NxHxRx(t1-t0)

        ######################################################################
        # Roll the gating indexes

        # The gate used for row r at relative time t is the one computed
        # for row (r + (t0 + t) // L) % R, so the assignment rotates by one
        # row every L time steps.

        warnings.warn("rotating barrel", RuntimeWarning)

        n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
        h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
        r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        r_barrel = (r_barrel + (t_barrel + t0) // L) % R

        G = G[n_barrel, h_barrel, r_barrel, t_barrel]

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT where R indexes the caterpillar's row. For each
            # sample selected with probability proba_gate_dropout, one
            # random time step per (sample, head) gets the forward value
            # 1 - epsilon and all later time steps get 0. The gradient with
            # respect to G is unchanged (the corrections use G.detach()).

            warnings.warn("gate dropout", RuntimeWarning)
            epsilon = 0.5

            dropout_head = (
                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
                .expand_as(G)
                .float()
            )

            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

            dropout_active = (
                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
            ).long()

            dropout_head *= dropout_active
            dropout_tail *= dropout_active

            G = (
                G
                + dropout_head * (1 - epsilon - G.detach())
                - dropout_tail * G.detach()
            )

        ######################################################################

        # We prepare the arguments for the parallel scan

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        A = 1 - G.sum(1)  # NxRx(t1-t0)

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # windowed_V / windowed_K are NxRxTxLxD: for each time t, the R rows
        # of the L most recent states t-L+1 ... t.

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
707
708
709 ##############################
710
711
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product attention with a key/value cache.

    Keys and values of each processed bracket are stored in per-instance
    caches, which are reallocated when ``bs.init_cache`` is True. Queries of
    the current bracket attend to every cached position up to the end of the
    bracket. Partial brackets are only accepted when ``causal`` is True; the
    causal mask hides positions strictly after the query.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def scaled_randn(*shape):
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        # When True, forward() keeps the last attention map in self.a
        self.record_attention = False

        self.w_q = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_k = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_v = scaled_randn(nb_heads, dim_v, dim_model)
        self.w_o = scaled_randn(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        start, stop = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            batch, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(batch, self.w_k.size(0), length, self.w_k.size(1))
            self.cache_v = x.new_zeros(batch, self.w_v.size(0), length, self.w_v.size(1))
            self.cache_y = x.new_zeros(batch, length, self.w_o.size(1))

        x_bracket = x[:, start:stop]
        q = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_q)
        self.cache_k[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_k)
        self.cache_v[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_v)

        keys = self.cache_k[:, :, :stop]
        values = self.cache_v[:, :, :stop]

        scores = torch.einsum("nhtd,nhsd->nhts", q, keys) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                positions = torch.arange(x.size(1), device=q.device)
                # True where the key position s comes strictly after the query t
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            hidden = self.cache_attzero[:, :, start:stop, :stop]
            scores = scores.masked_fill(hidden, float("-inf"))

        weights = scores.softmax(dim=3)

        if self.record_attention:
            self.a = weights

        weights = F.dropout(weights, self.attention_dropout, self.training)

        per_head = torch.einsum("nhts,nhsd->nthd", weights, values).flatten(2)
        self.cache_y[:, start:stop] = per_head @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
792
793
794 ##############################
795
796
797 class MyGPT(nn.Module):
798     def __init__(
799         self,
800         vocabulary_size,
801         dim_model,
802         dim_keys,
803         dim_hidden,
804         nb_heads,
805         nb_blocks,
806         nb_lines=None,
807         caterpillar_height=None,
808         causal=False,
809         dropout=0.0,
810         len_max=1e5,
811         attention_layer="kvrec",
812     ):
813         super().__init__()
814
815         assert attention_layer in {
816             "mha",
817             "dumbrec",
818             "kvrec",
819             "caterpillar",
820         }, f"Unknown attention operator {attention_layer}."
821
822         if attention_layer == "caterpillar":
823             assert nb_lines % caterpillar_height == 0
824             self.caterpillar_length = nb_lines // caterpillar_height
825             self.caterpillar_height = caterpillar_height
826         else:
827             self.caterpillar_length = -1
828             self.caterpillar_height = -1
829
830         assert dim_model % nb_heads == 0
831
832         self.embedding = nn.Sequential(
833             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
834             AddPositionalEncoding(len_max),
835         )
836
837         trunk_blocks = []
838
839         def attlayer():
840             if attention_layer == "mha":
841                 return QKVAttention(
842                     dim_model=dim_model,
843                     dim_qk=dim_keys,
844                     dim_v=dim_model // nb_heads,
845                     nb_heads=nb_heads,
846                     causal=causal,
847                     attention_dropout=dropout,
848                 )
849             elif attention_layer == "dumbrec":
850                 return DumbRec(
851                     dim_model=dim_model,
852                     dim_qk=dim_keys,
853                     dim_v=dim_model // nb_heads,
854                     nb_heads=nb_heads,
855                     nb_lines=nb_lines,
856                     attention_dropout=dropout,
857                 )
858             elif attention_layer == "kvrec":
859                 return KVRec(
860                     dim_model=dim_model,
861                     dim_qk=dim_keys,
862                     dim_v=dim_model // nb_heads,
863                     nb_heads=nb_heads,
864                     nb_lines=nb_lines,
865                     attention_dropout=dropout,
866                 )
867             elif attention_layer == "caterpillar":
868                 return Caterpillar(
869                     dim_model=dim_model,
870                     dim_qk=dim_keys,
871                     dim_v=dim_model // nb_heads,
872                     nb_heads=nb_heads,
873                     caterpillar_length=self.caterpillar_length,
874                     caterpillar_height=self.caterpillar_height,
875                     attention_dropout=dropout,
876                 )
877             else:
878                 raise ValueError(f"Unknown attention type {attention_layer}.")
879
880         for b in range(nb_blocks):
881             trunk_blocks += [
882                 WithResidual(
883                     CacheWrapper(nn.LayerNorm((dim_model,))),
884                     attlayer(),
885                 ),
886                 WithResidual(
887                     CacheWrapper(
888                         nn.LayerNorm((dim_model,)),
889                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
890                         nn.ReLU(),
891                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
892                         nn.Dropout(dropout),
893                     ),
894                 ),
895             ]
896
897         self.trunk = nn.Sequential(*trunk_blocks)
898
899         self.readout = CacheWrapper(
900             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
901         )
902
903         with torch.no_grad():
904             for m in self.modules():
905                 if isinstance(m, nn.Embedding):
906                     m.weight.normal_(mean=0, std=2e-2)
907                 elif isinstance(m, nn.LayerNorm):
908                     m.bias.zero_()
909                     m.weight.fill_(1.0)
910
911         self.reset_inner_loss()
912
913     def forward(self, bs):
914         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
915
916         # To make the code simpler in the Caterpillar layer, we pad
917         # here. It's unclear if/how much it hurts computationaly by
918         # increasing the sequence length for the other layers
919
920         if self.caterpillar_length > 0:
921             original_nb = bs.nb
922             if bs.nb % self.caterpillar_length > 0:
923                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
924
925             bs = BracketedSequence(
926                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
927                 bs.first + self.caterpillar_length,
928                 bs.nb,
929                 bs.init_cache,
930             )
931
932         bs = self.embedding(bs)
933         bs = self.trunk(bs)
934         bs = self.readout(bs)
935
936         if self.caterpillar_length > 0:
937             bs = BracketedSequence(
938                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
939                 bs.first - self.caterpillar_length,
940                 original_nb,
941                 bs.init_cache,
942             )
943
944         return bs
945
    # ar_mask is a tensor of 0s and 1s with the same shape as input, with
    # 1s where tokens should be generated. Tokens at the other positions
    # are kept unchanged.
949
950     def masked_inplace_autoregression(
951         self,
952         input_src,
953         ar_mask_src,
954         forbidden_tokens=None,
955         deterministic_synthesis=False,
956     ):
957         input = input_src.to(self.readout.f.weight.device)
958         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
959         to_generate = (ar_mask.sum(0) > 0).nonzero()
960         if to_generate.min() > 0:
961             self(
962                 BracketedSequence(input, 0, to_generate.min(), True)
963             )  # Needed to initialize the model's cache
964         for s in range(to_generate.min(), to_generate.max() + 1):
965             output = self(BracketedSequence(input, s, 1, s == 0)).x
966             logits = output[:, s]
967             if forbidden_tokens is not None:
968                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
969             if deterministic_synthesis:
970                 t_next = logits.argmax(1)
971             else:
972                 dist = torch.distributions.categorical.Categorical(logits=logits)
973                 t_next = dist.sample()
974             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
975
976         input_src.copy_(input)
977
978     def reset_inner_loss(self):
979         for m in self.modules():
980             if m is not self and hasattr(m, "reset_inner_loss"):
981                 m.reset_inner_loss()
982
983     def get_inner_loss(self):
984         l = torch.tensor([0.0], device=self.readout.f.weight.device)
985         for m in self.modules():
986             if m is not self and hasattr(m, "get_inner_loss"):
987                 l += m.get_inner_loss()
988         return l
989
990     def record_attention(self, v=True):
991         for m in self.modules():
992             if isinstance(m, QKVAttention):
993                 m.record_attention = v
994
995     def retrieve_attention(self):
996         a = []
997         for m in self.modules():
998             if isinstance(m, QKVAttention):
999                 a.append(m.a)
1000         return a
1001
1002
1003 ######################################################################
1004
if __name__ == "__main__":
    import sys
    import time

    # Consistency check for the Caterpillar layer:
    # - two identical runs on the same input are compared;
    # - one 21-step bracket is compared with two consecutive brackets of 14
    #   and 7 steps, where the second reuses the cache (init_cache=False).
    # Both printed values should be ~0 if the layer is deterministic and its
    # cache is consistent.
    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    # 21 time steps plus caterpillar_length (7) extra positions on each
    # side; the brackets below cover positions 7..27.
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # Stop here: the MyGPT timing benchmark below is currently disabled.
    # sys.exit is used rather than the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    # Timing: 3 rounds of 10 full-sequence forward passes in eval mode.
    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1073
1074     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1075
1076 ######################################################################