Update.
[mygptrnn.git] / mygpt.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.

# This implementation is equipped with RNN layers to replace the MHA.

import math, warnings

import torch, einops

from torch import nn
from torch.nn import functional as F

import ffutils

# import memload

######################################################################

# A BracketedSequence is a BxTx... tensor together with a first time
# step and a number of time steps to compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)


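# A minimal usage sketch, not part of the original code: how a module
# that maps a BracketedSequence to a BracketedSequence (as all the
# modules below do) is typically driven for prompt-conditioned
# generation. The argument names are illustrative only.


def bracketed_usage_sketch(model, x, prompt_len):
    # A first bracket covering the whole prompt, with init_cache=True
    # so that the module (re)initializes its internal cache.
    bs = model(BracketedSequence(x, 0, prompt_len, True))
    # Then brackets of length 1 for the successive tokens, with
    # init_cache=False so that the cached state is reused.
    for t in range(prompt_len, x.size(1)):
        bs = model(BracketedSequence(x, t, 1, False))
    return bs.x

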
######################################################################


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


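# A small self-contained check, not in the original code, of the
# phase-shift trick used above: adding pi/2 to the argument of sin on
# odd dimensions turns it into the cos of the even-dimension argument,
# so a single sin() call produces the interleaved sin/cos encoding.


def positional_encoding_sketch(T=5, D=6, len_max=1e5):
    t = torch.arange(T, dtype=torch.float64)[:, None]
    j = torch.arange(D, dtype=torch.float64)[None, :]
    k = j % 2
    pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    # Explicit reference: sin on even dimensions, cos on odd ones, both
    # with the exponent 2i/D of the corresponding even dimension
    ref = torch.where(
        k == 0,
        torch.sin(t / (len_max ** ((j - k) / D))),
        torch.cos(t / (len_max ** ((j - k) / D))),
    )
    return (pe - ref).abs().max()  # should be ~0

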
import pscan


# X is /.../xTxD   A is /.../xT   Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


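# For reference, not in the original code: a naive sequential
# equivalent of the scan in the simple NxT / NxTxD / NxD case,
# assuming pscan.pscan implements the first-order linear recurrence
# Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t] started from Y_init.


def naive_scan_sketch(A, X, Y_init):
    Y, result = Y_init, []
    for t in range(X.size(1)):
        Y = A[:, t, None] * Y + X[:, t]
        result.append(Y)
    return torch.stack(result, dim=1)

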
def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]
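        # After this indexing, the key used for memory line l at
        # absolute time t is self.k_star[(l + t) % nb_lines]: the
        # assignment of keys to lines rotates by one line per time step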

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


# Returns a tensor with an additional index at rank win_dim that moves
# along the same dimension as dim, over the domain {0...win_size-1};
# dim itself is restricted to a domain reduced by win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)


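# A small illustration, not in the original code: with dim=1,
# win_dim=2, win_size=3 on a 1x5 tensor, index win_dim selects a
# window of win_size consecutive values along dim, and dim itself is
# shortened by win_size - 1.


def moving_window_sketch():
    x = torch.arange(5)[None, :]  # shape (1, 5)
    w = moving_window(x, dim=1, win_dim=2, win_size=3)  # shape (1, 3, 3)
    # w[0] is [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
    return w

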
##############################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.proba_gate_dropout = 0.0

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Name the dimensions to make the source a bit clearer; that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), f"bs.first ({t0}) must be at least caterpillar_length ({L}), and bs.nb ({t1 - t0}) must be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        ######################################################################
        # Roll the gating indexes

        warnings.warn("rotating barrel", RuntimeWarning)
        n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
        h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
        r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        r_barrel = (r_barrel + t_barrel + t0) % R

        # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}")

        G = G[n_barrel, h_barrel, r_barrel, t_barrel]

        # print(G.sum())

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT, where R indexes the caterpillar's rows.

            warnings.warn("gate dropout", RuntimeWarning)
            epsilon = 0.5

            dropout_head = (
                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
                .expand_as(G)
                .float()
            )

            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

            dropout_active = (
                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
            ).long()

            dropout_head *= dropout_active
            dropout_tail *= dropout_active

            G = (
                G
                + dropout_head * (1 - epsilon - G.detach())
                - dropout_tail * G.detach()
            )
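            # In the forward pass this forces the gate, for the samples
            # where dropout_active is 1, to 1 - epsilon at one uniformly
            # drawn time step per (n, h) and to 0 at all later time
            # steps, while the detach() calls let the gradient of the
            # original G flow through unchanged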

        ######################################################################

        # We prepare the arguments for the parallel scan

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        A = 1 - G.sum(1)

        # warnings.warn("harmonic recurrence", RuntimeWarning)
        # har = torch.arange(t0, t1, device = G.device).float() + 1
        # A = har / (har + 1)
        # G = G / har

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

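        # With these definitions, the recurrence implemented by the
        # scan below is, for every row r and every time step t of the
        # current bracket,
        #
        #   rec_V[:, r, t] = A[:, r, t] * rec_V[:, r, t - L] + gated_V[:, r, t]
        #
        # and similarly for rec_K
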
        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)


##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor of 0s and 1s, of the same shape as the input,
    # with 1s where tokens should be generated; the other positions are
    # kept unchanged. A usage sketch is given after this class.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


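# A minimal usage sketch, not part of the original code: how
# masked_inplace_autoregression fills in the masked positions of a
# batch. `model` is assumed to be a MyGPT instance; the tensor
# contents are illustrative only.


def masked_autoregression_sketch(model, device):
    input = torch.zeros(2, 10, dtype=torch.int64, device=device)
    ar_mask = torch.zeros_like(input)
    ar_mask[:, 5:] = 1  # keep the first 5 tokens, generate the rest
    model.eval()
    with torch.no_grad():
        model.masked_inplace_autoregression(input, ar_mask)
    return input  # positions where ar_mask == 1 have been overwritten

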
######################################################################

if __name__ == "__main__":
    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################