[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is, a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers that can replace the MHA.
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor together with a first time
29 # step and a number of time steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True.
43
44
45 class BracketedSequence:
46     def __init__(self, x, first=None, nb=None, init_cache=None):
47         self.x = x
48         assert (first is None and nb is None and init_cache is None) or (
49             first is not None and nb is not None and init_cache is not None
50         )
51
52         self.first = 0 if first is None else first
53         self.nb = x.size(1) if nb is None else nb
54         self.init_cache = True if init_cache is None else init_cache
55
56     def slice(self):
57         return self.x[:, self.first : self.first + self.nb]
58
59     def complete(self):
60         return self.first == 0 and self.nb == self.x.size(1)
61
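# As an illustration only (not used elsewhere in this file): a
# hypothetical bracket-aware module would be driven as below, with a
# first bracket covering the prompt, then brackets of length 1.


def _example_bracketed_iteration(module, x, prompt_len):
    # Illustrative sketch; module is any module processing
    # BracketedSequence (e.g. the wrappers and attention layers below).
    bs = module(BracketedSequence(x, 0, prompt_len, True))
    for t in range(prompt_len, x.size(1)):
        bs = module(BracketedSequence(x, t, 1, False))
    return bs.x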
62
63 ######################################################################
64
65
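# CacheWrapper applies a module f that operates on each time step
# independently (e.g. an embedding or an MLP), and caches its output
# over the full sequence, so that successive brackets only compute the
# new time steps.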
66 class CacheWrapper(nn.Module):
67     def __init__(self, *f):
68         super().__init__()
69         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
70
71     def forward(self, bs):
72         if bs.init_cache:
73             y = self.f(bs.slice())
74             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
75             self.cache_y[:, bs.first : bs.first + bs.nb] = y
76         else:
77             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
78             assert bs.first + bs.nb <= self.cache_y.size(1)
79             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
80
81         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
87 class WithResidual(nn.Module):
88     def __init__(self, *f):
89         super().__init__()
90         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
91
92     def forward(self, bs):
93         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
99 class AddPositionalEncoding(nn.Module):
100     def __init__(self, len_max):
101         super().__init__()
102         self.len_max = len_max
103
104     # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
105
106     def forward(self, bs):
107         if bs.init_cache:
108             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109                 :, None
110             ]
111             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
112                 None, :
113             ]
114             k = j % 2
115             self.pe = torch.sin(
116                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
117             )
118             self.cache_y = bs.x.new(bs.x.size())
119
120         self.cache_y[:, bs.first : bs.first + bs.nb] = (
121             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122         )
123
124         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
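# A single sin() with a phase shift of pi/2 on the odd dimensions is
# equivalent to the usual sin/cos pair, since sin(a + pi/2) = cos(a).
# Minimal numerical check of that identity (illustrative only, not
# called anywhere):


def _check_positional_encoding(len_max=1e5, T=8, D=6):
    t = torch.arange(T, dtype=torch.float64)[:, None]
    j = torch.arange(D, dtype=torch.float64)[None, :]
    k = j % 2
    pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    ref = torch.where(
        k == 0,
        torch.sin(t / (len_max ** (j / D))),
        torch.cos(t / (len_max ** ((j - 1) / D))),
    )
    return (pe - ref).abs().max()

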
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
133 def pscan_dim(A, X, Y_init, dim=-2):
134     s = X.size()
135     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
136
137     A = A.reshape(a, T, *s[dim + 1 : -1])
138     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139
140     if Y_init is None:
141         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
142     else:
143         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
144
145     Y = pscan.pscan(A, X, Y_init).reshape(s)
146
147     return Y
148
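# A naive sequential reference for what pscan_dim computes along dim=2
# (illustrative only; assumes pscan.pscan implements the recurrence
# Y[t] = A[t] * Y[t-1] + X[t] with Y[-1] = Y_init, which is how it is
# used below):


def _naive_pscan_dim2(A, X, Y_init):
    # A is NxRxT, X is NxRxTxD, Y_init is NxRxD (hypothetical shapes)
    Y, out = Y_init, []
    for t in range(X.size(2)):
        Y = A[:, :, t, None] * Y + X[:, :, t]
        out.append(Y)
    return torch.stack(out, dim=2)
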
149
150 def pscan_shape(A, X, Y_init):
151     s = X.size()
152     A = A.reshape(-1, s[-2])
153     X = X.reshape(-1, s[-2], s[-1])
154
155     if Y_init is None:
156         Y_init = X.new_zeros(X.size(0), s[-1])
157     else:
158         Y_init = Y_init.reshape(-1, s[-1])
159
160     Y = pscan.pscan(A, X, Y_init).reshape(s)
161
162     return Y
163
164
165 def nsum_shape(X, Y_init):
166     s = X.size()
167     X = X.reshape(-1, s[-2], s[-1])  # ntd
168
169     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170     result = []
171
172     for k in range(X.size(1)):
173         Y = Y + X[:, k]
174         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175         result.append(Y)
176
177     return torch.cat(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
183 class DumbRec(nn.Module):
184     def __init__(
185         self,
186         dim_model,
187         dim_qk,
188         dim_v,
189         nb_heads,
190         nb_lines,
191         attention_dropout=0.0,
192         len_max=1e5,
193         logger=print,
194         **kwargs,
195     ):
196         super().__init__()
197
198         def randw(*d):
199             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
200
201         self.nb_lines = nb_lines
202         self.attention_dropout = attention_dropout
203
204         self.k_star = randw(nb_lines, dim_qk)
205
206         self.w_qw = randw(nb_heads, dim_qk, dim_model)
207         self.w_qr = randw(nb_heads, dim_qk, dim_model)
208         # self.w_k = randw(nb_heads, dim_qk, dim_model)
209         self.w_v = randw(nb_heads, dim_v, dim_model)
210         self.w_o = randw(dim_v * nb_heads, dim_model)
211
212     def reset_inner_loss(self):
213         self.acc_attention = 0
214         self.acc_nb = 0
215
216     def get_inner_loss(self):
217         warnings.warn("l2 regularization", RuntimeWarning)
218         return (self.acc_attention / self.acc_nb).pow(2).sum()
219         # return torch.tensor([0], device=self.w_qw.device)
220
221     def forward(self, bs):
222         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
223
224         if bs.init_cache:
225             self.rec_v = x_q.new_zeros(
226                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
227             )
228             # self.rec_k = x_q.new_zeros(
229             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
230             # )
231             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
232
233         ######################################################################
234         # Prepare the keys
235
236         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
237
238         warnings.warn("rotating key barrel", RuntimeWarning)
239         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
240         t_barrel = torch.arange(t0, t1, device=k_star.device)
241         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
242         l_barrel = (
243             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
244         ) % k_star.size(0)
245         k_star = k_star[l_barrel, t_barrel]
246
247         ######################################################################
248         # Compute the recurrent state
249
250         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
251
252         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
253         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
254
255         aw = torch.einsum(
256             "nhtd,ltd->nhlt",
257             qw,
258             k_star,
259         ) / math.sqrt(self.w_qw.size(1))
260
261         aw = aw.softmax(dim=2)  # nhlt
262
263         if self.training:
264             self.acc_attention += aw.sum(dim=(0, 1, 3))
265             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
266
267         aw = F.dropout(aw, self.attention_dropout, self.training)
268
269         A = 1 - aw.sum(dim=1)  # nlt
270
271         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
272         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
273
274         if t0 == 0:
275             V0 = None
276             # K0 = None
277         else:
278             V0 = self.rec_v[:, :, t0 - 1]
279             # K0 = self.rec_k[:, :, t0 - 1]
280
281         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
282         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
283
284         ######################################################################
285         # compute the readout
286
287         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
288
289         ar = torch.einsum(
290             "nhtd,ld->nhlt",
291             qr,
292             # self.rec_k[:, :, t0:t1],
293             self.k_star,
294         ) / math.sqrt(self.w_qr.size(1))
295
296         ar = ar.softmax(dim=2)  # nhlt
297
298         ar = F.dropout(ar, self.attention_dropout, self.training)
299
300         y = torch.einsum(
301             "nhlt,nltd->nthd",
302             ar,
303             self.rec_v[:, :, t0:t1],
304         ).flatten(2)
305
306         self.cache_y[:, t0:t1] = y @ self.w_o
307
308         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
309
310
311 ##############################
312
313
314 class KVRec(nn.Module):
315     def __init__(
316         self,
317         dim_model,
318         dim_qk,
319         dim_v,
320         nb_heads,
321         nb_lines,
322         attention_dropout=0.0,
323         len_max=1e5,
324         logger=print,
325         **kwargs,
326     ):
327         super().__init__()
328
329         def randw(*d):
330             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
331
332         self.nb_lines = nb_lines
333         self.attention_dropout = attention_dropout
334
335         self.k_star = randw(nb_lines, dim_qk)
336
337         self.w_qw = randw(nb_heads, dim_qk, dim_model)
338         self.w_qr = randw(nb_heads, dim_qk, dim_model)
339         self.w_k = randw(nb_heads, dim_qk, dim_model)
340         self.w_v = randw(nb_heads, dim_v, dim_model)
341         self.w_o = randw(dim_v * nb_heads, dim_model)
342
343     def reset_inner_loss(self):
344         self.acc_attention = 0
345         self.acc_nb = 0
346
347     def get_inner_loss(self):
348         warnings.warn("l2 regularization", RuntimeWarning)
349         return (self.acc_attention / self.acc_nb).pow(2).sum()
350         # return torch.tensor([0], device=self.w_qw.device)
351         # warnings.warn("side regularization", RuntimeWarning)
352         # return (
353         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
354         # )
355         # return torch.tensor([0], device=self.w_qw.device)
356
357     def forward(self, bs):
358         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
359
360         if bs.init_cache:
361             self.rec_v = x_q.new_zeros(
362                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
363             )
364             self.rec_k = x_q.new_zeros(
365                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
366             )
367             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
368
369         ######################################################################
370         # Prepare the keys
371
372         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
373
374         warnings.warn("rotating key barrel", RuntimeWarning)
375         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
376         t_barrel = torch.arange(t0, t1, device=k_star.device)
377         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
378         l_barrel = (
379             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
380         ) % k_star.size(0)
381         k_star = k_star[l_barrel, t_barrel]
382
383         ######################################################################
384         # Compute the recurrent state
385
386         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
387
388         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
389         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
390
391         aw = torch.einsum(
392             "nhtd,ltd->nhlt",
393             qw,
394             k_star,
395         ) / math.sqrt(self.w_qw.size(1))
396
397         aw = aw.softmax(dim=2)  # nhlt
398
399         if self.training:
400             # We want all the memory lines to be used similarly
401             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
402             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
403
404         aw = F.dropout(aw, self.attention_dropout, self.training)
405
406         A = 1 - aw.sum(dim=1)  # nlt
407
408         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
409         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
410
411         if t0 == 0:
412             V0 = None
413             K0 = None
414         else:
415             V0 = self.rec_v[:, :, t0 - 1]
416             K0 = self.rec_k[:, :, t0 - 1]
417
418         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
419         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
420
421         ######################################################################
422         # compute the readout
423
424         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
425
426         ar = torch.einsum(
427             "nhtd,nltd->nhlt",
428             qr,
429             self.rec_k[:, :, t0:t1],
430         ) / math.sqrt(self.w_qr.size(1))
431
432         ar = ar.softmax(dim=2)  # nhlt
433
434         ar = F.dropout(ar, self.attention_dropout, self.training)
435
436         y = torch.einsum(
437             "nhlt,nltd->nthd",
438             ar,
439             self.rec_v[:, :, t0:t1],
440         ).flatten(2)
441
442         self.cache_y[:, t0:t1] = y @ self.w_o
443
444         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
445
446
447 ##############################
448
449
450 # Returns a view of x with an additional index at rank win_dim, that moves
451 # along the same dimension as dim, over the domain {0...win_size-1}, and
452 # dim is restricted to a domain reduced by win_size-1 values.
453
454
455 def moving_window(x, dim, win_dim, win_size):
456     size, stride = x.size(), x.stride()
457     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
458     size = size[:win_dim] + (win_size,) + size[win_dim:]
459     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
460
461     return x.as_strided(size=size, stride=stride)
462
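# For illustration: with x of shape (2, 5, 3), moving_window(x, dim=1,
# win_dim=2, win_size=3) has shape (2, 3, 3, 3), is a strided view (no
# copy), and its element [n, t, i, d] is x[n, t + i, d].


def _moving_window_example():
    x = torch.arange(2 * 5 * 3).reshape(2, 5, 3)
    w = moving_window(x, dim=1, win_dim=2, win_size=3)
    assert w.size() == (2, 3, 3, 3)
    assert (w[:, 1, 2] == x[:, 3]).all()
    return w
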
463
464 ##############################
465
466
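# Calibrator accumulates per-feature first and second moments of the
# activations passed to update(); normalize() then recenters an optional
# bias by subtracting the running mean, rescales an optional weight by
# dividing by the running std, and resets the accumulators (descriptive
# note, inferred from the code below).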
467 class Calibrator:
468     def __init__(self, w=None, b=None):
469         self.w = w
470         self.b = b
471         self.s, self.s_sq, self.n = 0, 0, 0
472         self.mean, self.std = 0, 0
473
474     def update(self, X):
475         X = X.detach()
476         self.s += X.sum(dim=0)
477         self.s_sq += X.pow(2).sum(dim=0)
478         self.n += X.size(0)
479
480     def moments(self):
481         mean = self.s / self.n
482         std = (self.s_sq / self.n - mean * mean).sqrt()
483         return mean, std
484
485     def normalize(self):
486         mean, std = self.moments()
487         if self.b is not None:
488             self.b.sub_(mean)
489         if self.w is not None:
490             self.w.div_(std)
491         result = mean - self.mean, std - self.std
492         self.mean, self.std = mean, std
493         self.s, self.s_sq, self.n = 0, 0, 0
494         return result
495
496
497 class Caterpillar(nn.Module):
498     def __init__(
499         self,
500         dim_model,
501         dim_qk,
502         dim_v,
503         nb_heads,
504         caterpillar_length,
505         caterpillar_height,
506         attention_dropout=0.0,
507         len_max=1e5,
508         logger=print,
509         **kwargs,
510     ):
511         super().__init__()
512
513         warnings.warn("Caterpillar", RuntimeWarning)
514
515         def randw(*d, amplitude=None):
516             if amplitude is None:
517                 amplitude = 1 / math.sqrt(d[-1])
518             return nn.Parameter(amplitude * torch.randn(*d))
519
520         self.caterpillar_length = caterpillar_length
521         self.caterpillar_height = caterpillar_height
522         self.attention_dropout = attention_dropout
523
524         ######################################################################
525         # sup_args
526
527         x = kwargs.get("gate_dropout")
528         if x is None:
529             self.proba_gate_dropout = 0.0
530         else:
531             self.proba_gate_dropout = float(x)
532
533         logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")
534
535         x = kwargs.get("default_bg")
536         if x is None:
537             default_bg = -math.log(caterpillar_height - 1)
538         else:
539             default_bg = float(x)
540
541         logger(f"default_bg {default_bg}")
542
543         ######################################################################
544
545         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
546         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
547
548         self.w_K = randw(nb_heads, dim_qk, dim_model)
549         self.w_V = randw(nb_heads, dim_v, dim_model)
550         self.w_Q = randw(nb_heads, dim_qk, dim_model)
551         self.w_O = randw(dim_v * nb_heads, dim_model)
552
553         self.init_K_rec = randw(
554             caterpillar_height,
555             caterpillar_length,
556             dim_qk,
557         )
558         self.init_V_rec = randw(
559             caterpillar_height,
560             caterpillar_length,
561             dim_v,
562         )
563
564         self.calibrator_G = Calibrator()
565         self.calibrator_rec_V = Calibrator()
566         self.calibrator_rec_K = Calibrator()
567
568     def reset_inner_loss(self):
569         self.acc_attention = 0
570         self.acc_nb = 0
571
572     def get_inner_loss(self):
573         # warnings.warn("l2 regularization", RuntimeWarning)
574         # return (self.acc_attention / self.acc_nb).pow(2).sum()
575         return torch.tensor([0], device=self.w_Q.device)
576
577     def forward(self, bs):
578         # Shorthand names for the dimensions, to make the code below clearer
579
580         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
581
582         N = bs.x.size(0)
583         T = bs.x.size(1)
584         H = self.w_V.size(0)
585         DV = self.w_V.size(1)
586         DK = self.w_K.size(1)
587         DM = self.w_O.size(1)
588         R = self.caterpillar_height
589         L = self.caterpillar_length
590
591         assert (
592             t0 >= L and (t1 - t0) % L == 0
593         ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
594
595         # We cache values to deal efficiently with auto-regression
596
597         if bs.init_cache:
598             self.rec_V = X.new_zeros(N, R, T, DV)
599             self.rec_K = X.new_zeros(N, R, T, DK)
600             # We start the recurrent sequences with optimizable
601             # initial values. No idea if it helps.
602             self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
603             self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
604
605             self.cache_Y = X.new_zeros(N, T, DM)
606
607         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
608         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
609
610         ######################################################################
611         # Compute the recurrent state
612
613         # This is the Gating sequence that modulates the storing of
614         # the new key and value in the R pairs of the current
615         # stack. There are R independent gating values, which means
616         # that the current K/V may be stored in multiple pairs of the
617         # recurrent state, or not at all.
618
619         G = (
620             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
621         ).sigmoid()
622
623         self.calibrator_G.update(G.reshape(-1, G.size(-1)))
624
625         # warnings.warn("softmax gating", RuntimeWarning)
626
627         # G = (
628         # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
629         # ).softmax(dim=2)
630
631         ######################################################################
632         # The "flashbacks"
633
634         if self.training and self.proba_gate_dropout > 0.0:
635             # This is a better implementation of "flashbacks".
636
637             # G is NxHxRxT, where R indexes the caterpillar's rows.
638
639             warnings.warn("gate dropout", RuntimeWarning)
640
641             kill = (
642                 torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
643             ).float()
644
645             alpha = G / (1 - self.proba_gate_dropout)
646
647             G = alpha * (1 - kill)
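            # As with standard (inverted) dropout, dividing by
            # 1 - proba_gate_dropout compensates for the killed entries,
            # so that E[alpha * (1 - kill)] = G.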
648
649         ######################################################################
650         # Clip the gating to avoid values greater than 1 when several
651         # heads hit the same row
652
653         G = G / G.sum(1, keepdim=True).clamp(min=1)
654
655         ######################################################################
656         # Roll the gating indexes
657
658         # warnings.warn("rotating barrel", RuntimeWarning)
659
660         # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
661         # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
662         # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
663         # G = G.gather(dim=2, index=r_barrel.expand_as(G))
664
665         # We prepare the arguments for the parallel scan
666
667         A = 1 - G.sum(1)
668
669         # warnings.warn("harmonic recurrence", RuntimeWarning)
670         # har = torch.arange(t0, t1, device = G.device).float() + 1
671         # A = har / (har + 1)
672         # G = G / har
673
674         gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
675         gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
676
677         # We start from cached values, which matters in inference
678
679         init_rec_V = self.rec_V[:, :, t0 - L : t0]
680         init_rec_K = self.rec_K[:, :, t0 - L : t0]
681
682         #################################################################
683         # Associative scan
684
685         # Here there is a trick: Since the stack at position t is
686         # computed by updating that at position t-L, the parallel
687         # scan operates with a period of L. To do so we split the
688         # sequence indexing into two axes, the second of size L, and
689         # run the parallel scan using the first as the sequence index.
690
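        # Shapes, for illustration: if the bracket contains T' = t1 - t0
        # time steps, A goes from (N, R, T') to (N, R, T'/L, L) and
        # gated_V from (N, R, T', DV) to (N, R, T'/L, L, DV); the scan
        # below runs over dim=2, i.e. with a stride of L in time.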
691         A = A.unflatten(2, (-1, L))
692         gated_V = gated_V.unflatten(2, (-1, L))
693         gated_K = gated_K.unflatten(2, (-1, L))
694
695         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
696         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
697
698         next_V = next_V.flatten(2, 3)
699         next_K = next_K.flatten(2, 3)
700
701         self.calibrator_rec_V.update(
702             next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
703         )
704         self.calibrator_rec_K.update(
705             next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
706         )
707
708         self.rec_V[:, :, t0:t1] = next_V
709         self.rec_K[:, :, t0:t1] = next_K
710
711         ######################################################################
712         # compute the readout
713
714         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
715
716         # We build tensors NxHxTxFxL where N is the sample index, H
717         # the head, T the time, F the row in the caterpillar, and L
718         # the column in the caterpillar
719
720         windowed_V = moving_window(
721             self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
722         )
723
724         windowed_K = moving_window(
725             self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
726         )
727
728         # We have an attention score for each of the RxL values
729
730         ar = torch.einsum(
731             "nhtd,nftld->nhtfl",
732             Q,
733             windowed_K,
734         ) / math.sqrt(DK)
735
736         # softmax can operate only on one dimension, hence the
737         # flattening
738
739         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
740
741         ar = F.dropout(ar, self.attention_dropout, self.training)
742
743         # Compute the output for each head, flatten to concatenate
744
745         Y = torch.einsum(
746             "nhtfl,nftld->nthd",
747             ar,
748             windowed_V,
749         ).flatten(2)
750
751         # Compute the final output
752
753         self.cache_Y[:, t0:t1] = Y @ self.w_O
754
755         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
756
757
758 ##############################
759
760
761 class QKVAttention(nn.Module):
762     def __init__(
763         self,
764         dim_model,
765         dim_qk,
766         dim_v,
767         nb_heads=1,
768         causal=False,
769         attention_dropout=0.0,
770         logger=print,
771         **kwargs,
772     ):
773         super().__init__()
774
775         def randw(*d):
776             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
777
778         self.causal = causal
779         self.attention_dropout = attention_dropout
780         self.record_attention = False
781
782         self.w_q = randw(nb_heads, dim_qk, dim_model)
783         self.w_k = randw(nb_heads, dim_qk, dim_model)
784         self.w_v = randw(nb_heads, dim_v, dim_model)
785         self.w_o = randw(dim_v * nb_heads, dim_model)
786
787     def forward(self, bs):
788         x_q = bs.x
789
790         assert (
791             self.causal or bs.complete()
792         ), "Partial evaluation is only possible for causal models"
793
794         if bs.init_cache:
795             self.cache_k = x_q.new_zeros(
796                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
797             )
798             self.cache_v = x_q.new_zeros(
799                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
800             )
801             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
802
803         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
804
805         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
806             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
807         )
808         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
809             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
810         )
811
812         a = torch.einsum(
813             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
814         ) / math.sqrt(self.w_q.size(1))
815
816         if self.causal:
817             if bs.init_cache:
818                 self.cache_attzero = (
819                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
820                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
821                 )
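            # cache_attzero[..., t, s] is True where s > t, i.e. where a
            # query would attend to a future position and must be masked
            # out.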
822             a = a.masked_fill(
823                 self.cache_attzero[
824                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
825                 ],
826                 float("-inf"),
827             )
828
829         a = a.softmax(dim=3)
830
831         if self.record_attention:
832             self.a = a
833
834         a = F.dropout(a, self.attention_dropout, self.training)
835
836         y = torch.einsum(
837             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
838         ).flatten(2)
839
840         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
841
842         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
843
844
845 ##############################
846
847
848 class MyGPT(nn.Module):
849     def __init__(
850         self,
851         vocabulary_size,
852         dim_model,
853         dim_keys,
854         dim_hidden,
855         nb_heads,
856         nb_blocks,
857         nb_lines=None,
858         caterpillar_height=None,
859         causal=False,
860         dropout=0.0,
861         len_max=1e5,
862         attention_layer="kvrec",
863         logger=print,
864         **kwargs,
865     ):
866         super().__init__()
867
868         assert attention_layer in {
869             "mha",
870             "dumbrec",
871             "kvrec",
872             "caterpillar",
873         }, f"Unknown attention operator {attention_layer}."
874
875         if attention_layer == "caterpillar":
876             assert nb_lines % caterpillar_height == 0
877             self.caterpillar_length = nb_lines // caterpillar_height
878             self.caterpillar_height = caterpillar_height
879         else:
880             self.caterpillar_length = -1
881             self.caterpillar_height = -1
882
883         assert dim_model % nb_heads == 0
884
885         self.embedding = nn.Sequential(
886             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
887             AddPositionalEncoding(len_max),
888         )
889
890         trunk_blocks = []
891
892         def attlayer():
893             if attention_layer == "mha":
894                 return QKVAttention(
895                     dim_model=dim_model,
896                     dim_qk=dim_keys,
897                     dim_v=dim_model // nb_heads,
898                     nb_heads=nb_heads,
899                     causal=causal,
900                     attention_dropout=dropout,
901                     logger=logger,
902                     **kwargs,
903                 )
904             elif attention_layer == "dumbrec":
905                 return DumbRec(
906                     dim_model=dim_model,
907                     dim_qk=dim_keys,
908                     dim_v=dim_model // nb_heads,
909                     nb_heads=nb_heads,
910                     nb_lines=nb_lines,
911                     attention_dropout=dropout,
912                     logger=logger,
913                     **kwargs,
914                 )
915             elif attention_layer == "kvrec":
916                 return KVRec(
917                     dim_model=dim_model,
918                     dim_qk=dim_keys,
919                     dim_v=dim_model // nb_heads,
920                     nb_heads=nb_heads,
921                     nb_lines=nb_lines,
922                     attention_dropout=dropout,
923                     logger=logger,
924                     **kwargs,
925                 )
926             elif attention_layer == "caterpillar":
927                 return Caterpillar(
928                     dim_model=dim_model,
929                     dim_qk=dim_keys,
930                     dim_v=dim_model // nb_heads,
931                     nb_heads=nb_heads,
932                     caterpillar_length=self.caterpillar_length,
933                     caterpillar_height=self.caterpillar_height,
934                     attention_dropout=dropout,
935                     logger=logger,
936                     **kwargs,
937                 )
938             else:
939                 raise ValueError(f"Unknown attention type {attention_layer}.")
940
941         for b in range(nb_blocks):
942             trunk_blocks += [
943                 WithResidual(
944                     CacheWrapper(nn.LayerNorm((dim_model,))),
945                     attlayer(),
946                 ),
947                 WithResidual(
948                     CacheWrapper(
949                         nn.LayerNorm((dim_model,)),
950                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
951                         nn.ReLU(),
952                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
953                         nn.Dropout(dropout),
954                     ),
955                 ),
956             ]
957
958         self.trunk = nn.Sequential(*trunk_blocks)
959
960         self.readout = CacheWrapper(
961             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
962         )
963
964         with torch.no_grad():
965             for m in self.modules():
966                 if isinstance(m, nn.Embedding):
967                     m.weight.normal_(mean=0, std=2e-2)
968                 elif isinstance(m, nn.LayerNorm):
969                     m.bias.zero_()
970                     m.weight.fill_(1.0)
971
972         self.reset_inner_loss()
973
974     def forward(self, bs):
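        # Shift the token sequence one position to the right (dropping
        # the last token), so that, with causal attention, the logits at
        # position t are computed from the original tokens at positions
        # < t only.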
975         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
976
977         # To make the code simpler in the Caterpillar layer, we pad
978         # here. It's unclear if/how much it hurts computationally by
979         # increasing the sequence length for the other layers.
980
981         if self.caterpillar_length > 0:
982             original_nb = bs.nb
983             if bs.nb % self.caterpillar_length > 0:
984                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
985
986             bs = BracketedSequence(
987                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
988                 bs.first + self.caterpillar_length,
989                 bs.nb,
990                 bs.init_cache,
991             )
992
993         bs = self.embedding(bs)
994         bs = self.trunk(bs)
995         bs = self.readout(bs)
996
997         if self.caterpillar_length > 0:
998             bs = BracketedSequence(
999                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
1000                 bs.first - self.caterpillar_length,
1001                 original_nb,
1002                 bs.init_cache,
1003             )
1004
1005         return bs
1006
1007     # ar_mask is a tensor with 0s and 1s, of the same shape as the input, with
1008     # 1s where tokens should be generated. The others are kept
1009     # unchanged.
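    # For instance (illustrative, with hypothetical tensors): for a
    # batch x whose first p tokens are a fixed prompt,
    #
    #   ar_mask = torch.ones_like(x)
    #   ar_mask[:, :p] = 0
    #   model.masked_inplace_autoregression(x, ar_mask)
    #
    # overwrites x[:, p:] in place with sampled tokens.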
1010
1011     def masked_inplace_autoregression(
1012         self,
1013         input_src,
1014         ar_mask_src,
1015         forbidden_tokens=None,
1016         deterministic_synthesis=False,
1017     ):
1018         input = input_src.to(self.readout.f.weight.device)
1019         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
1020         to_generate = (ar_mask.sum(0) > 0).nonzero()
1021         if to_generate.min() > 0:
1022             self(
1023                 BracketedSequence(input, 0, to_generate.min(), True)
1024             )  # Needed to initialize the model's cache
1025         for s in range(to_generate.min(), to_generate.max() + 1):
1026             output = self(BracketedSequence(input, s, 1, s == 0)).x
1027             logits = output[:, s]
1028             if forbidden_tokens is not None:
1029                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
1030             if deterministic_synthesis:
1031                 t_next = logits.argmax(1)
1032             else:
1033                 dist = torch.distributions.categorical.Categorical(logits=logits)
1034                 t_next = dist.sample()
1035             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
1036
1037         input_src.copy_(input)
1038
1039     def reset_inner_loss(self):
1040         for m in self.modules():
1041             if m is not self and hasattr(m, "reset_inner_loss"):
1042                 m.reset_inner_loss()
1043
1044     def get_inner_loss(self):
1045         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1046         for m in self.modules():
1047             if m is not self and hasattr(m, "get_inner_loss"):
1048                 l += m.get_inner_loss()
1049         return l
1050
1051     def record_attention(self, v=True):
1052         for m in self.modules():
1053             if isinstance(m, QKVAttention):
1054                 m.record_attention = v
1055
1056     def retrieve_attention(self):
1057         a = []
1058         for m in self.modules():
1059             if isinstance(m, QKVAttention):
1060                 a.append(m.a)
1061         return a
1062
1063
1064 ######################################################################
1065
1066 if __name__ == "__main__":
1067     print("Basic check.")
1068
1069     m = Caterpillar(
1070         dim_model=4,
1071         dim_qk=3,
1072         dim_v=7,
1073         nb_heads=1,
1074         caterpillar_length=7,
1075         caterpillar_height=3,
1076         attention_dropout=0.0,
1077     )
1078
1079     m.reset_inner_loss()
1080     x = torch.randn(1, 21 + 2 * 7, 4)
1081     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1082     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1083     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1084     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1085     print((y1 - y2).abs().max())
1086     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1087     exit(0)
1088
1089     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1090
1091     vocabulary_size = 128
1092     x = torch.randint(vocabulary_size, (6, 1024))
1093
1094     model = MyGPT(
1095         vocabulary_size=vocabulary_size,
1096         dim_model=512,
1097         dim_keys=64,
1098         dim_hidden=2048,
1099         nb_heads=8,
1100         nb_lines=128,
1101         nb_blocks=12,
1102         dropout=0.1,
1103         causal=True,
1104     )
1105
1106     x = x.to(device)
1107     model.to(device)
1108
1109     import time, sys
1110
1111     # import torchvision.models as models
1112     # from torch.profiler import profile, record_function, ProfilerActivity
1113
1114     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1115     # with record_function("model_inference"):
1116
1117     model.eval()
1118     for i in range(3):
1119         start_time = time.perf_counter()
1120         for k in range(10):
1121             model(BracketedSequence(x))
1122         duration = time.perf_counter() - start_time
1123         print(duration)
1124         sys.stdout.flush()
1125
1126     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1127     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1128
1129     # print("##############################################################")
1130     # y2 = torch.randn_like(y1)
1131     # for s in range(x.size(1)):
1132     # z = model(BracketedSequence(x, s, 1))
1133     # y2[:, s : s + 1] = z.slice()
1134
1135     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1136
1137 ######################################################################