Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: the first one, `first`, and their number, `nb`.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process there will be a first bracket starting at 0
# and of arbitrary length for the "prompt", followed by brackets of
# length 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is reset
# when the input bracket has `init_cache` set (typically when it starts
# at t=0).
41
42
class BracketedSequence:
    """A BxTx... tensor paired with the bracket of time steps to compute.

    The bracket is the half-open interval [first, first + nb) along axis 1.
    Either all of first / nb / init_cache are given, or none of them: in the
    latter case the bracket spans the whole time axis and init_cache is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x

        given = [arg is not None for arg in (first, nb, init_cache)]
        assert all(given) or not any(given)

        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        if init_cache is None:
            init_cache = True

        self.first, self.nb, self.init_cache = first, nb, init_cache

    def slice(self):
        """Return the part of x that lies inside the bracket (time axis 1)."""
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        """Tell whether the bracket covers the whole time axis."""
        return self.first == 0 and self.nb == self.x.size(1)
59
60
61 ######################################################################
62
63
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket only, caching outputs over time.

    The wrapped module (or the nn.Sequential chaining several of them) is
    applied to ``bs.slice()``; its output is written into a buffer covering
    the whole time axis, so later brackets do not recompute earlier steps.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled rather than uninitialized: time steps that are not
            # computed yet hold deterministic values instead of garbage.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
            self.cache_y[:, t0:t1] = y
        else:
            # The cache must have been initialized for a sequence of the
            # same batch size and length.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert t1 <= self.cache_y.size(1)
            self.cache_y[:, t0:t1] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module: x + f(x)."""

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        # The sum is taken over the whole time axis; the bracket is unchanged.
        skip = bs.x
        branch = self.f(bs).x
        return BracketedSequence(skip + branch, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
class AddPositionalEncoding(nn.Module):
    """Add the sinusoidal positional encoding to the current bracket.

    [Vaswani et al. 2017] PE_{t,2i} = sin(t / L^{2i/D}),
    PE_{t,2i+1} = cos(t / L^{2i/D}), with L = len_max and D the size of the
    last axis of the input. The encoding table and the output cache are
    rebuilt whenever the bracket has init_cache set.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
            k = j % 2
            # cos(u) = sin(u + pi/2): odd channels get a pi/2 phase shift
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled rather than uninitialized, so that time steps not
            # computed yet hold deterministic values.
            self.cache_y = x.new_zeros(x.size())

        self.cache_y[:, t0:t1] = bs.slice() + self.pe[t0:t1]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan along axis ``dim`` of X, which must not be the last axis.

    X has the scanned axis at ``dim`` and the feature axis last. A has the
    shape of X without its last axis. Y_init has the shape of X without axis
    ``dim``; None means a zero initial state. All axes before ``dim`` are
    flattened into one batch axis for pscan.pscan, and axes strictly between
    ``dim`` and the last one are kept. The result has the shape of X.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    middle = s[dim + 1 : -1]  # axes kept between the scanned and last one

    A = A.reshape(a, T, *middle)
    X = X.reshape(a, T, *middle, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *middle, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *middle, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
146
147
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan over the second-to-last axis of X.

    X is .../xTxD, A is .../xT and Y_init is .../xD (None means zeros). The
    leading axes are flattened into one batch axis for pscan.pscan and
    restored on the result, which has the shape of X.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)
    flat_Y0 = (
        flat_X.new_zeros(flat_X.size(0), D)
        if Y_init is None
        else Y_init.reshape(-1, D)
    )

    return pscan.pscan(flat_A, flat_X, flat_Y0).reshape(shape)
161
162
def nsum_shape(X, Y_init):
    """Sequential running sum over the second-to-last axis, norm-clamped to 1.

    Y_t = S_t / max(1, ||S_t||) with S_t = Y_{t-1} + X_t, starting from
    Y_init (or 0 when Y_init is None). X is .../xTxD, Y_init is .../xD, and
    the result has the shape of X.
    """
    shape = X.size()
    steps = X.reshape(-1, shape[-2], shape[-1])  # ntd

    acc = 0 if Y_init is None else Y_init.reshape(-1, shape[-1])
    outputs = []

    for x_t in steps.unbind(dim=1):
        acc = acc + x_t
        acc = acc / acc.norm(dim=-1, keepdim=True).clamp(min=1)
        outputs.append(acc)

    return torch.stack(outputs, dim=1).reshape(shape)
176
177
178 ##############################
179
180
class DumbRec(nn.Module):
    """Recurrent memory of nb_lines value lines, written and read by attention.

    Writes: each head attends over learned per-line keys (k_star, rotated
    over time) and the resulting weighted values are accumulated into the
    lines with a parallel scan. Reads: each head attends over the same
    learned keys (not rotated) and mixes the lines' current values. Unlike
    KVRec, no recurrent keys are stored.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # len_max is accepted for interface compatibility but unused.

        # Make the accumulators exist even if nobody resets them first.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Sum of squares of the average write attention of each line.

        The averages sum to 1 over lines, so this favors using all lines
        evenly. Returns a 0-dim zero tensor if no training forward pass was
        accumulated since the last reset (instead of dividing 0 by 0).
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: memory line l at time t is written with the key
        # k_star[(l + t) % nb_lines] ("rotating key barrel")

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # self.training (a bool), not self.train (a bound method, always
        # truthy): statistics are only accumulated in training mode.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line factor passed to the scan: 1 minus the total write
        # attention that line receives from all heads.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # The scan continues from the state just before the bracket.
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
class KVRec(nn.Module):
    """Recurrent memory of nb_lines (key, value) lines, written and read by attention.

    Writes: each head attends over learned per-line keys (k_star, rotated
    over time) and the resulting weighted keys and values are accumulated
    into the lines with parallel scans. Reads: each head attends over the
    lines' current recurrent keys and mixes their recurrent values.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # len_max is accepted for interface compatibility but unused.

        # Make the accumulators exist even if nobody resets them first.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Sum of squares of the average write attention of each line.

        The averages sum to 1 over lines, so this favors using all memory
        lines similarly. Returns a 0-dim zero tensor if no training forward
        pass was accumulated since the last reset (instead of dividing 0 by 0).
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: memory line l at time t is written with the key
        # k_star[(l + t) % nb_lines] ("rotating key barrel")

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # self.training (a bool), not self.train (a bound method, always
        # truthy): statistics are only accumulated in training mode.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line factor passed to the scans: 1 minus the total write
        # attention that line receives from all heads.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # The scans continue from the state just before the bracket.
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
# Returns a tensor with an additional index at rank win_dim that moves
# along the same dimension as dim, on the domain {0...win_size-1}, while
# dim is restricted to a domain reduced by win_size-1 values.
447
448
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with an extra axis of sliding windows.

    A new axis of size ``win_size`` is inserted at position ``win_dim``
    (Python-slice position among x's axes). Along it the index moves along
    axis ``dim``, whose size becomes ``x.size(dim) - win_size + 1`` window
    starts. Thus out[..., t, ..., l, ...] == x[..., t + l, ...]. No data is
    copied; the result shares x's storage, including its storage offset.

    ``dim`` may be negative. Raises ValueError if win_size is not in
    [1, x.size(dim)].
    """
    if dim < 0:
        # Normalize, otherwise size[dim + 1 :] is wrong for dim == -1
        dim += x.dim()

    if not 1 <= win_size <= x.size(dim):
        raise ValueError(
            f"win_size must be in [1, {x.size(dim)}] for dim {dim}, got {win_size}"
        )

    size, stride = tuple(x.size()), tuple(x.stride())
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    # The window axis steps through memory exactly like axis dim
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride, storage_offset=x.storage_offset())
456
457
458 ##############################
459
460
class Caterpillar(nn.Module):
    """Recurrent attention layer with a CH x CL ("caterpillar") key/value state.

    The recurrent state holds CH rows of (key, value) pairs per time step.
    The state at time t is obtained by updating the one at time t - CL with
    sigmoid-gated new keys/values, computed with a parallel scan of period CL.
    The readout at time t attends over the CH x CL pairs of the states at
    times t - CL + 1 ... t.

    The processed bracket must start at t0 >= CL and have a length that is a
    multiple of CL.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        warnings.warn("flash back", RuntimeWarning)
        # Training-only probability used by the "flashback" perturbation of
        # the recurrent state (see forward)
        self.proba_flashback = 0.1

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        # sigmoid(-log(CH - 1)) = 1 / CH: gates start around 1 / CH
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No inner loss for this layer: a float zero, so that it combines
        # with floating-point losses without integer promotion.
        return torch.tensor([0.0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Note: applying dropout to G was tried, and that was a bad idea.

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        if self.training and self.proba_flashback > 0.0:
            # "Flashback": each coordinate of the new recurrent keys/values
            # is, with probability proba_flashback, replaced by the same
            # coordinate of the key/value of a random head at an earlier time
            # of the bracket with the same phase modulo CL.
            #
            # This piece of code makes the assumption that there is
            # nothing informative before t0, otherwise we'd have to
            # implement a cache for V and K too. This should not be
            # too much of a problem since this is used only during
            # train, where full sequence are available

            n = torch.arange(N, device=X.device)[:, None, None, None]
            t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
            dv = torch.arange(DV, device=X.device)[None, None, None, :]
            dk = torch.arange(DK, device=X.device)[None, None, None, :]

            # u is a multiple of CL with 0 <= u <= t - t0, so that src_time,
            # an index into the bracket-local V/K, is in [0, t - t0]: never
            # in the future. Scaling by t instead of t - t0 could give
            # u > t - t0, i.e. a negative index that wraps around to the end
            # of the bracket and reads future keys/values.
            u = (
                torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t - t0).long() // CL
            ) * CL

            src_time = t - u - t0
            src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)

            mask_V = (
                torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
            ).long()
            self.rec_V[:, :, t0:t1] = (
                mask_V * V[n, src_head, src_time, dv]
                + (1 - mask_V) * self.rec_V[:, :, t0:t1]
            )

            mask_K = (
                torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
            ).long()
            self.rec_K[:, :, t0:t1] = (
                mask_K * K[n, src_head, src_time, dk]
                + (1 - mask_K) * self.rec_K[:, :, t0:t1]
            )

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxFxTxLxD where N is the sample index, F the
        # row in the caterpillar, T the time, L the column in the
        # caterpillar, and D the key/value dimension

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
669
670
671 ##############################
672
673
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with key/value caching.

    Keys and values of every processed time step are cached. Each bracket
    computes queries only for its own time steps and attends to all cached
    keys up to the end of the bracket. A non-causal model must be evaluated
    on complete sequences.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        # When True, forward keeps the post-softmax attention in self.a
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        N, T = x_q.size(0), x_q.size(1)

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(N, self.w_k.size(0), T, self.w_k.size(1))
            self.cache_v = x_q.new_zeros(N, self.w_v.size(0), T, self.w_v.size(1))
            self.cache_y = x_q.new_zeros(N, T, self.w_o.size(1))

        x_bracket = x_q[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_v)

        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                steps = torch.arange(T, device=q.device)
                # True where the key time s is strictly after the query time t
                self.cache_attzero = (
                    steps[None, None, :, None] < steps[None, None, None, :]
                )
            a = a.masked_fill(self.cache_attzero[:, :, t0:t1, :t1], float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        heads = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1])
        self.cache_y[:, t0:t1] = heads.flatten(2) @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
754
755
756 ##############################
757
758
759 class MyGPT(nn.Module):
760     def __init__(
761         self,
762         vocabulary_size,
763         dim_model,
764         dim_keys,
765         dim_hidden,
766         nb_heads,
767         nb_blocks,
768         nb_lines=None,
769         caterpillar_height=None,
770         dim_rec_v=-1,
771         causal=False,
772         dropout=0.0,
773         len_max=1e5,
774         attention_layer="kvrec",
775     ):
776         super().__init__()
777
778         assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
779
780         if attention_layer == "caterpillar":
781             assert nb_lines % caterpillar_height == 0
782             self.caterpillar_length = nb_lines // caterpillar_height
783             self.caterpillar_height = caterpillar_height
784         else:
785             self.caterpillar_length = -1
786             self.caterpillar_height = -1
787
788         assert dim_model % nb_heads == 0
789
790         self.embedding = nn.Sequential(
791             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
792             AddPositionalEncoding(len_max),
793         )
794
795         trunk_blocks = []
796
797         def attlayer():
798             if attention_layer == "mha":
799                 return QKVAttention(
800                     dim_model=dim_model,
801                     dim_qk=dim_keys,
802                     dim_v=dim_model // nb_heads,
803                     nb_heads=nb_heads,
804                     causal=causal,
805                     attention_dropout=dropout,
806                 )
807             elif attention_layer == "dumbrec":
808                 return DumbRec(
809                     dim_model=dim_model,
810                     dim_qk=dim_keys,
811                     dim_v=dim_rec_v,
812                     nb_heads=nb_heads,
813                     nb_lines=nb_lines,
814                     attention_dropout=dropout,
815                 )
816             elif attention_layer == "kvrec":
817                 return KVRec(
818                     dim_model=dim_model,
819                     dim_qk=dim_keys,
820                     dim_v=dim_rec_v,
821                     nb_heads=nb_heads,
822                     nb_lines=nb_lines,
823                     attention_dropout=dropout,
824                 )
825             elif attention_layer == "caterpillar":
826                 return Caterpillar(
827                     dim_model=dim_model,
828                     dim_qk=dim_keys,
829                     dim_v=dim_rec_v,
830                     nb_heads=nb_heads,
831                     caterpillar_length=self.caterpillar_length,
832                     caterpillar_height=self.caterpillar_height,
833                     attention_dropout=dropout,
834                 )
835             else:
836                 raise ValueError(f"Unknown attention type {attention_layer}.")
837
838         for b in range(nb_blocks):
839             trunk_blocks += [
840                 WithResidual(
841                     CacheWrapper(nn.LayerNorm((dim_model,))),
842                     attlayer(),
843                 ),
844                 WithResidual(
845                     CacheWrapper(
846                         nn.LayerNorm((dim_model,)),
847                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
848                         nn.ReLU(),
849                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
850                         nn.Dropout(dropout),
851                     ),
852                 ),
853             ]
854
855         self.trunk = nn.Sequential(*trunk_blocks)
856
857         self.readout = CacheWrapper(
858             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
859         )
860
861         with torch.no_grad():
862             for m in self.modules():
863                 if isinstance(m, nn.Embedding):
864                     m.weight.normal_(mean=0, std=2e-2)
865                 elif isinstance(m, nn.LayerNorm):
866                     m.bias.zero_()
867                     m.weight.fill_(1.0)
868
869         self.reset_inner_loss()
870
871     def forward(self, bs):
872         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
873
874         # To make the code simpler in the Caterpillar layer, we pad
875         # here. It's unclear if/how much it hurts computationaly by
876         # increasing the sequence length for the other layers
877
878         if self.caterpillar_length > 0:
879             original_nb = bs.nb
880             if bs.nb % self.caterpillar_length > 0:
881                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
882
883             bs = BracketedSequence(
884                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
885                 bs.first + self.caterpillar_length,
886                 bs.nb,
887                 bs.init_cache,
888             )
889
890         bs = self.embedding(bs)
891         bs = self.trunk(bs)
892         bs = self.readout(bs)
893
894         if self.caterpillar_length > 0:
895             bs = BracketedSequence(
896                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
897                 bs.first - self.caterpillar_length,
898                 original_nb,
899                 bs.init_cache,
900             )
901
902         return bs
903
904     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
905     # 1s where tokens should be generated. The others are kept
906     # unchanged.
907
908     def masked_inplace_autoregression(
909         self,
910         input_src,
911         ar_mask_src,
912         forbidden_tokens=None,
913         deterministic_synthesis=False,
914     ):
915         input = input_src.to(self.readout.f.weight.device)
916         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
917         to_generate = (ar_mask.sum(0) > 0).nonzero()
918         if to_generate.min() > 0:
919             self(
920                 BracketedSequence(input, 0, to_generate.min(), True)
921             )  # Needed to initialize the model's cache
922         for s in range(to_generate.min(), to_generate.max() + 1):
923             output = self(BracketedSequence(input, s, 1, s == 0)).x
924             logits = output[:, s]
925             if forbidden_tokens is not None:
926                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
927             if deterministic_synthesis:
928                 t_next = logits.argmax(1)
929             else:
930                 dist = torch.distributions.categorical.Categorical(logits=logits)
931                 t_next = dist.sample()
932             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
933
934         input_src.copy_(input)
935
936     def reset_inner_loss(self):
937         for m in self.modules():
938             if m is not self and hasattr(m, "reset_inner_loss"):
939                 m.reset_inner_loss()
940
941     def get_inner_loss(self):
942         l = torch.tensor([0.0], device=self.readout.f.weight.device)
943         for m in self.modules():
944             if m is not self and hasattr(m, "get_inner_loss"):
945                 l += m.get_inner_loss()
946         return l
947
948     def record_attention(self, v=True):
949         for m in self.modules():
950             if isinstance(m, QKVAttention):
951                 m.record_attention = v
952
953     def retrieve_attention(self):
954         a = []
955         for m in self.modules():
956             if isinstance(m, QKVAttention):
957                 a.append(m.a)
958         return a
959
960
961 ######################################################################
962
963 if __name__ == "__main__":
964     print("Basic check.")
965
966     m = Caterpillar(
967         dim_model=4,
968         dim_qk=3,
969         dim_v=7,
970         nb_heads=1,
971         caterpillar_length=7,
972         caterpillar_height=3,
973         attention_dropout=0.0,
974     )
975
976     m.reset_inner_loss()
977     x = torch.randn(1, 21 + 2 * 7, 4)
978     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
979     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
980     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
981     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
982     print((y1 - y2).abs().max())
983     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
984     exit(0)
985
986     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
987
988     vocabulary_size = 128
989     x = torch.randint(vocabulary_size, (6, 1024))
990
991     model = MyGPT(
992         vocabulary_size=vocabulary_size,
993         dim_model=512,
994         dim_keys=64,
995         dim_hidden=2048,
996         nb_heads=8,
997         nb_lines=128,
998         nb_blocks=12,
999         dropout=0.1,
1000         causal=True,
1001     )
1002
1003     x = x.to(device)
1004     model.to(device)
1005
1006     import time, sys
1007
1008     # import torchvision.models as models
1009     # from torch.profiler import profile, record_function, ProfilerActivity
1010
1011     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1012     # with record_function("model_inference"):
1013
1014     model.eval()
1015     for i in range(3):
1016         start_time = time.perf_counter()
1017         for k in range(10):
1018             model(BracketedSequence(x))
1019         duration = time.perf_counter() - start_time
1020         print(duration)
1021         sys.stdout.flush()
1022
1023     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1024     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1025
1026     # print("##############################################################")
1027     # y2 = torch.randn_like(y1)
1028     # for s in range(x.size(1)):
1029     # z = model(BracketedSequence(x, s, 1))
1030     # y2[:, s : s + 1] = z.slice()
1031
1032     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1033
1034 ######################################################################