Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
26 # A BracketedSequence is a BxTx... tensor with a first and a nb time
27 # steps to compute.
28
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
32 # no holes.
33 #
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
38 #
# Modules able to process brackets may implement a cache that is
# reset when the input bracket starts at t=0
41
42
class BracketedSequence:
    """A BxTx... tensor together with the time bracket to process.

    `first` is the index of the first time step of the bracket, `nb`
    the number of time steps it spans, and `init_cache` tells the
    modules processing the sequence whether they should (re)allocate
    their caches.

    Either all of `first`, `nb` and `init_cache` are provided, or none
    of them, in which case the bracket covers the whole time axis of
    `x` and the caches are initialized.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        bracket_args = (first, nb, init_cache)
        nb_provided = sum(a is not None for a in bracket_args)

        # All-or-nothing: a partially specified bracket is an error
        assert nb_provided in (0, len(bracket_args))

        self.x = x

        if nb_provided == 0:
            self.first, self.nb, self.init_cache = 0, x.size(1), True
        else:
            self.first, self.nb, self.init_cache = first, nb, init_cache

    def slice(self):
        """Return the part of `x` inside the bracket, along dimension 1."""
        last = self.first + self.nb
        return self.x[:, self.first : last]

    def complete(self):
        """Return True if the bracket covers the full time axis of `x`."""
        if self.first != 0:
            return False
        return self.nb == self.x.size(1)
59
60
61 ######################################################################
62
63
class CacheWrapper(nn.Module):
    """Wrap one or several modules so that they process a
    BracketedSequence.

    The wrapped module(s) are applied to the slice of the input inside
    the bracket only, and the result is written in a cache that spans
    the full time axis. Several modules are chained with
    nn.Sequential.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            y = self.f(bs.slice())
            # The cache has the batch and trailing dimensions of the
            # output, and the time dimension of the full input. It is
            # not initialized outside of the bracket.
            self.cache_y = y.new(y.size(0), bs.x.size(1), *y.size()[2:])
        else:
            # The cache allocated at the previous init must be
            # consistent with the input
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert t1 <= self.cache_y.size(1)
            y = self.f(bs.slice())

        self.cache_y[:, t0:t1] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
class WithResidual(nn.Module):
    """Add a residual connection around one or several modules that
    process a BracketedSequence.

    Several modules are chained with nn.Sequential. The output keeps
    the bracket of the input.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        branch = self.f(bs)
        summed = bs.x + branch.x
        return BracketedSequence(summed, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to a BracketedSequence.

    With L = len_max and D the size of the last dimension of the
    input, the encoding is that of [Vaswani et al 2017]:

      PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    The encoding table and the output cache are rebuilt every time the
    input bracket has init_cache set.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            nb_steps, dim = x.size(1), x.size(2)
            t = torch.arange(nb_steps, dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(dim, dtype=x.dtype, device=x.device)[None, :]
            # k is 0 for the even channels and 1 for the odd ones, the
            # latter get a phase shift of pi/2 to turn the sine into a
            # cosine
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / dim)) + math.pi / 2 * k
            )
            # Not initialized outside of the bracket
            self.cache_y = x.new(x.size())

        self.cache_y[:, t0:t1] = bs.slice() + self.pe[t0:t1]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan along dimension `dim` of X.

    All the dimensions of X before `dim` are merged into a single
    leading one, the dimensions between `dim` and the last one are
    kept as they are, and A and Y_init are reshaped accordingly. If
    Y_init is None, a tensor of zeros is used instead. The result has
    the shape of X.
    """
    full_shape = X.size()
    nb_leading = full_shape[:dim].numel()
    nb_steps = full_shape[dim]
    # Dimensions strictly between the scanned one and the last one
    inner = full_shape[dim + 1 : -1]

    A_flat = A.reshape(nb_leading, nb_steps, *inner)
    X_flat = X.reshape(nb_leading, nb_steps, *inner, -1)

    if Y_init is not None:
        Y0 = Y_init.reshape(nb_leading, *inner, -1)
    else:
        Y0 = X_flat.new_zeros(nb_leading, *inner, X_flat.size(-1))

    return pscan.pscan(A_flat, X_flat, Y0).reshape(full_shape)
146
147
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan along the second to last dimension of X.

    X is /.../xTxD, A is /.../xT and Y_init is /.../xD or None. All
    the leading dimensions are merged into a single one for the scan.
    If Y_init is None, a tensor of zeros is used instead. The result
    has the shape of X.
    """
    full_shape = X.size()
    nb_steps, dim = full_shape[-2], full_shape[-1]

    A_flat = A.reshape(-1, nb_steps)
    X_flat = X.reshape(-1, nb_steps, dim)

    if Y_init is not None:
        Y0 = Y_init.reshape(-1, dim)
    else:
        Y0 = X_flat.new_zeros(X_flat.size(0), dim)

    return pscan.pscan(A_flat, X_flat, Y0).reshape(full_shape)
161
162
def nsum_shape(X, Y_init):
    """Cumulated sum along the second to last dimension of X, with a
    re-normalization after each step.

    X is /.../xTxD and Y_init is /.../xD or None, in which case the
    sum starts from zero. After each addition, the running vector is
    divided by its L2 norm if that norm is greater than 1. The result
    has the shape of X.
    """
    full_shape = X.size()
    steps = X.reshape(-1, full_shape[-2], full_shape[-1])  # ntd

    state = 0 if Y_init is None else Y_init.reshape(-1, full_shape[-1])
    states = []

    for x_t in steps.unbind(dim=1):
        state = state + x_t
        # The clamp prevents vectors of norm less than 1 from being
        # scaled up
        state = state / state.norm(dim=-1, keepdim=True).clamp(min=1)
        states.append(state)

    # Each state is nd, concatenating along d and reshaping puts the
    # time dimension back
    return torch.cat(states, dim=1).reshape(full_shape)
176
177
178 ##############################
179
180
class DumbRec(nn.Module):
    """Recurrent attention layer with `nb_lines` memory lines of values.

    At each time step the input is written into the memory lines with
    attention weights computed against the learned keys `k_star`
    (rotated across lines as time goes, the "rotating key barrel"),
    and the output is read from the memory lines with attention
    weights computed against `k_star` directly.

    The recurrent state is updated with pscan_shape, presumably
    Y_t = A_t Y_{t-1} + X_t -- to be confirmed in the pscan module.

    `len_max` is accepted for interface compatibility and not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # So that forward() can accumulate the attention statistics
        # even if reset_inner_loss() was not called from outside
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Reset the accumulators of the write attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over the memory lines of the squared average
        write attention accumulated since the last reset.

        Raises ZeroDivisionError if nothing was accumulated.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        """Process the bracket [bs.first, bs.first + bs.nb) of bs.x, which
        is NxTxC with C=dim_model, and return a BracketedSequence whose
        tensor is the NxTxC output cache.
        """
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # The key used for line l at time t is k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag, while self.train is the
        # method that sets it and is always true as a condition. The
        # statistics are for the training loss only.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # The recurrence restarts from scratch at t=0, and from the
        # last cached state otherwise

        if t0 == 0:
            V0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
class KVRec(nn.Module):
    """Recurrent attention layer with `nb_lines` memory lines of keys
    and values.

    At each time step the input is written into the memory lines with
    attention weights computed against the learned keys `k_star`
    (rotated across lines as time goes, the "rotating key barrel"),
    and the output is read from the recurrent values with attention
    weights computed against the recurrent keys.

    The recurrent states are updated with pscan_shape, presumably
    Y_t = A_t Y_{t-1} + X_t -- to be confirmed in the pscan module.

    `len_max` is accepted for interface compatibility and not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # So that forward() can accumulate the attention statistics
        # even if reset_inner_loss() was not called from outside
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Reset the accumulators of the write attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over the memory lines of the squared average
        write attention accumulated since the last reset.

        Raises ZeroDivisionError if nothing was accumulated.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        """Process the bracket [bs.first, bs.first + bs.nb) of bs.x, which
        is NxTxC with C=dim_model, and return a BracketedSequence whose
        tensor is the NxTxC output cache.
        """
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # The key used for line l at time t is k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag, while self.train is the
        # method that sets it and is always true as a condition. The
        # statistics are for the training loss only.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # The recurrence restarts from scratch at t=0, and from the
        # last cached state otherwise

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, on a domain {0...win_size-1}, while
# dim is restricted to a domain reduced by win_size-1 values.
447
448
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with sliding windows along `dim`.

    The result has an additional dimension of size win_size inserted
    at rank win_dim, and its dimension `dim` is shortened by
    win_size-1. Moving by one along the new dimension moves by one
    along dimension `dim` of x. No data is copied.
    """
    shape, strides = tuple(x.size()), tuple(x.stride())

    # Number of complete windows that fit along dim
    nb_windows = shape[dim] - win_size + 1
    shrunk = shape[:dim] + (nb_windows,) + shape[dim + 1 :]

    # The window index advances in memory exactly as dim does
    view_shape = shrunk[:win_dim] + (win_size,) + shrunk[win_dim:]
    view_strides = strides[:win_dim] + (strides[dim],) + strides[win_dim:]

    return x.as_strided(size=view_shape, stride=view_strides)
456
457
458 ##############################
459
460 # This is one order of magnitude more complicated than I expected
461
462
def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
    """Sample random "flash backs" of length CL over a bracket of
    t1-t0 time steps.

    Returns two long tensors of shape N x CH x (t1-t0), src_time and
    src_head. Inside a flash back, src_time is a non-negative source
    time index that increases by one at each step, and src_head is a
    head index in {0...H-1} that is constant over the flash back.
    Outside of the flash backs, src_time is -1 and src_head is 0.
    """
    # starting flash backs

    # Each time step is a candidate start with probability proba, but
    # none can start in the first or the last CL steps
    fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long()
    fb_start[:, :, -CL:] = 0
    fb_start[:, :, :CL] = 0

    # Remove series longer than CL

    # After the cumsum, fb_body counts the starts in the CL+1 last
    # steps, current one included. A start is kept only if it is
    # alone there, so that the flash backs do not overlap.
    fb_body = fb_start.clone()
    fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)]
    fb_body = fb_body.cumsum(dim=2)
    fb_start = fb_start * (fb_body == 1)

    # pick past starting source times

    # For a start at t, a random time in {0...t-CL-1}. The difference
    # with the copy delayed by CL followed by a cumsum holds that
    # value during the CL steps of the flash back, and zero elsewhere.
    src_time = (
        fb_start
        * (
            torch.rand(fb_start.size(), device=fb_start.device)
            * (torch.arange(fb_start.size(2), device=fb_start.device) - CL)[
                None, None, :
            ]
        ).long()
    )
    src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
    src_time = src_time.cumsum(dim=2)

    # Same hold-for-CL-steps trick for a random source head
    src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device)
    src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL]
    src_head = src_head.cumsum(dim=2)

    # combine

    # After the first cumsum src_delta is 1 inside the flash backs and
    # 0 outside. With the -CL put right after the end of each flash
    # back, its cumsum minus 1 is the offset 0...CL-1 inside the flash
    # backs, and -1 outside.
    src_delta = fb_start.clone()
    src_delta[:, :, CL:] -= fb_start[:, :, :-CL]
    src_delta = src_delta.cumsum(dim=2)
    src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL]
    src_time += src_delta.cumsum(dim=2) - 1

    return src_time, src_head
500
501
def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
    """Overwrite in place parts of rec_V[:, :, t0:t1] and
    rec_K[:, :, t0:t1] with values picked from V and K, at the source
    times and heads sampled by flash_back_time_src.

    rec_V and rec_K are indexed as N x CH x T x D, V and K as
    N x H x T x D. The same flash backs are used for the values and
    the keys. Returns nothing.
    """
    N, H, CH = V.size(0), V.size(1), rec_V.size(1)

    fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device)

    # Broadcast the source times and heads along the channel dimension
    fbt_V = fbt[:, :, :, None].expand_as(rec_V[:, :, t0:t1])
    fbh_V = fbh[:, :, :, None].expand_as(rec_V[:, :, t0:t1])
    # Negative times mean "no flash back", they are clamped to be valid
    # indices and masked out below
    t = fbt_V.clamp(min=0)
    n = torch.arange(V.size(0), device=V.device)[:, None, None, None].expand_as(
        rec_V[:, :, t0:t1]
    )
    d = torch.arange(V.size(3), device=V.device)[None, None, None, :].expand_as(
        rec_V[:, :, t0:t1]
    )
    # NOTE(review): in Caterpillar.forward, V and K are computed from
    # bs.slice() and their time dimension is already of length t1-t0,
    # so slicing them again with t0:t1 looks suspicious -- confirm
    # this is intended
    q = V[:, :, t0:t1][n, fbh_V, t, d]
    rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0)

    # Same for the keys
    fbt_K = fbt[:, :, :, None].expand_as(rec_K[:, :, t0:t1])
    fbh_K = fbh[:, :, :, None].expand_as(rec_K[:, :, t0:t1])
    t = fbt_K.clamp(min=0)
    n = torch.arange(K.size(0), device=K.device)[:, None, None, None].expand_as(
        rec_K[:, :, t0:t1]
    )
    d = torch.arange(K.size(3), device=K.device)[None, None, None, :].expand_as(
        rec_K[:, :, t0:t1]
    )
    q = K[:, :, t0:t1][n, fbh_K, t, d]
    rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0)

    # print("SANITY", (fbt_K >=0).float().sum()/fbt_K.numel())
532
533
534 ######################################################################
535
536
class Caterpillar(nn.Module):
    """Recurrent attention layer whose state is a stack of
    caterpillar_height rows of keys and values, the state at time t
    being computed from that at time t - caterpillar_length.

    The readout at time t attends to the caterpillar_height x
    caterpillar_length keys and values of the caterpillar_length last
    time steps of the recurrent state.

    `len_max` is accepted for interface compatibility and not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # With this bias the initial gating values are around
        # 1/caterpillar_height. Note that the log is not defined for
        # caterpillar_height=1.
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        """Reset the accumulators of the attention statistics. They
        are not used by this layer."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return a tensor with a single zero, this layer has no inner
        loss."""
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        """Process the bracket [bs.first, bs.first + bs.nb) of bs.x and
        return a BracketedSequence whose tensor is the output cache.

        bs.first has to be at least caterpillar_length, and bs.nb a
        multiple of caterpillar_length.
        """
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # That was a bad idea
        # G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        # During training only, parts of the recurrent state are
        # randomly overwritten with past keys and values

        warnings.warn("flash back", RuntimeWarning)
        if self.training:
            insert_flash_back(self.rec_V, V, self.rec_K, K, t0, t1, CL, proba=1e-2 / CL)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
708
709
710 ##############################
711
712
class QKVAttention(nn.Module):
    """Multi-head attention layer processing a BracketedSequence, with
    caches for the keys, the values and the output.

    If `causal` is set, a query at time t attends only to the keys at
    times s <= t, and the layer can process brackets that do not
    cover the full sequence. If `record_attention` is set to True, the
    attention weights are kept in the attribute `a`.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            scale = math.sqrt(d[-1])
            return nn.Parameter(torch.randn(*d) / scale)

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        t0, t1 = bs.first, bs.first + bs.nb
        N, T = x_q.size(0), x_q.size(1)

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(N, self.w_k.size(0), T, self.w_k.size(1))
            self.cache_v = x_q.new_zeros(N, self.w_v.size(0), T, self.w_v.size(1))
            self.cache_y = x_q.new_zeros(N, T, self.w_o.size(1))

        # Queries, keys and values are computed for the bracket only,
        # the keys and values of the previous time steps come from the
        # caches

        x_bracket = x_q[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_q)

        self.cache_k[:, :, t0:t1] = torch.einsum(
            "ntc,hdc->nhtd", x_bracket, self.w_k
        )
        self.cache_v[:, :, t0:t1] = torch.einsum(
            "ntc,hdc->nhtd", x_bracket, self.w_v
        )

        past_k = self.cache_k[:, :, :t1]
        past_v = self.cache_v[:, :, :t1]

        a = torch.einsum("nhtd,nhsd->nhts", q, past_k) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                # True where the key is strictly in the future of the
                # query
                pos = torch.arange(T, device=q.device)
                self.cache_attzero = (
                    pos[None, None, :, None] < pos[None, None, None, :]
                )
            a = a.masked_fill(self.cache_attzero[:, :, t0:t1, :t1], float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        # One output per head, flattened to concatenate them

        y = torch.einsum("nhts,nhsd->nthd", a, past_v).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
793
794
795 ##############################
796
797
798 class MyGPT(nn.Module):
799     def __init__(
800         self,
801         vocabulary_size,
802         dim_model,
803         dim_keys,
804         dim_hidden,
805         nb_heads,
806         nb_blocks,
807         nb_lines=None,
808         caterpillar_height=None,
809         dim_rec_v=-1,
810         causal=False,
811         dropout=0.0,
812         len_max=1e5,
813         attention_layer="kvrec",
814     ):
815         super().__init__()
816
817         assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
818
819         if attention_layer == "caterpillar":
820             assert nb_lines % caterpillar_height == 0
821             self.caterpillar_length = nb_lines // caterpillar_height
822             self.caterpillar_height = caterpillar_height
823         else:
824             self.caterpillar_length = -1
825             self.caterpillar_height = -1
826
827         assert dim_model % nb_heads == 0
828
829         self.embedding = nn.Sequential(
830             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
831             AddPositionalEncoding(len_max),
832         )
833
834         trunk_blocks = []
835
836         def attlayer():
837             if attention_layer == "mha":
838                 return QKVAttention(
839                     dim_model=dim_model,
840                     dim_qk=dim_keys,
841                     dim_v=dim_model // nb_heads,
842                     nb_heads=nb_heads,
843                     causal=causal,
844                     attention_dropout=dropout,
845                 )
846             elif attention_layer == "dumbrec":
847                 return DumbRec(
848                     dim_model=dim_model,
849                     dim_qk=dim_keys,
850                     dim_v=dim_rec_v,
851                     nb_heads=nb_heads,
852                     nb_lines=nb_lines,
853                     attention_dropout=dropout,
854                 )
855             elif attention_layer == "kvrec":
856                 return KVRec(
857                     dim_model=dim_model,
858                     dim_qk=dim_keys,
859                     dim_v=dim_rec_v,
860                     nb_heads=nb_heads,
861                     nb_lines=nb_lines,
862                     attention_dropout=dropout,
863                 )
864             elif attention_layer == "caterpillar":
865                 return Caterpillar(
866                     dim_model=dim_model,
867                     dim_qk=dim_keys,
868                     dim_v=dim_rec_v,
869                     nb_heads=nb_heads,
870                     caterpillar_length=self.caterpillar_length,
871                     caterpillar_height=self.caterpillar_height,
872                     attention_dropout=dropout,
873                 )
874             else:
875                 raise ValueError(f"Unknown attention type {attention_layer}.")
876
877         for b in range(nb_blocks):
878             trunk_blocks += [
879                 WithResidual(
880                     CacheWrapper(nn.LayerNorm((dim_model,))),
881                     attlayer(),
882                 ),
883                 WithResidual(
884                     CacheWrapper(
885                         nn.LayerNorm((dim_model,)),
886                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
887                         nn.ReLU(),
888                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
889                         nn.Dropout(dropout),
890                     ),
891                 ),
892             ]
893
894         self.trunk = nn.Sequential(*trunk_blocks)
895
896         self.readout = CacheWrapper(
897             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
898         )
899
900         with torch.no_grad():
901             for m in self.modules():
902                 if isinstance(m, nn.Embedding):
903                     m.weight.normal_(mean=0, std=2e-2)
904                 elif isinstance(m, nn.LayerNorm):
905                     m.bias.zero_()
906                     m.weight.fill_(1.0)
907
908         self.reset_inner_loss()
909
910     def forward(self, bs):
911         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
912
913         # To make the code simpler in the Caterpillar layer, we pad
914         # here. It's unclear if/how much it hurts computationaly by
915         # increasing the sequence length for the other layers
916
917         if self.caterpillar_length > 0:
918             original_nb = bs.nb
919             if bs.nb % self.caterpillar_length > 0:
920                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
921
922             bs = BracketedSequence(
923                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
924                 bs.first + self.caterpillar_length,
925                 bs.nb,
926                 bs.init_cache,
927             )
928
929         bs = self.embedding(bs)
930         bs = self.trunk(bs)
931         bs = self.readout(bs)
932
933         if self.caterpillar_length > 0:
934             bs = BracketedSequence(
935                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
936                 bs.first - self.caterpillar_length,
937                 original_nb,
938                 bs.init_cache,
939             )
940
941         return bs
942
943     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
944     # 1s where tokens should be generated. The others are kept
945     # unchanged.
946
947     def masked_inplace_autoregression(
948         self,
949         input_src,
950         ar_mask_src,
951         forbidden_tokens=None,
952         deterministic_synthesis=False,
953     ):
954         input = input_src.to(self.readout.f.weight.device)
955         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
956         to_generate = (ar_mask.sum(0) > 0).nonzero()
957         if to_generate.min() > 0:
958             self(
959                 BracketedSequence(input, 0, to_generate.min(), True)
960             )  # Needed to initialize the model's cache
961         for s in range(to_generate.min(), to_generate.max() + 1):
962             output = self(BracketedSequence(input, s, 1, s == 0)).x
963             logits = output[:, s]
964             if forbidden_tokens is not None:
965                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
966             if deterministic_synthesis:
967                 t_next = logits.argmax(1)
968             else:
969                 dist = torch.distributions.categorical.Categorical(logits=logits)
970                 t_next = dist.sample()
971             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
972
973         input_src.copy_(input)
974
975     def reset_inner_loss(self):
976         for m in self.modules():
977             if m is not self and hasattr(m, "reset_inner_loss"):
978                 m.reset_inner_loss()
979
980     def get_inner_loss(self):
981         l = torch.tensor([0.0], device=self.readout.f.weight.device)
982         for m in self.modules():
983             if m is not self and hasattr(m, "get_inner_loss"):
984                 l += m.get_inner_loss()
985         return l
986
987     def record_attention(self, v=True):
988         for m in self.modules():
989             if isinstance(m, QKVAttention):
990                 m.record_attention = v
991
992     def retrieve_attention(self):
993         a = []
994         for m in self.modules():
995             if isinstance(m, QKVAttention):
996                 a.append(m.a)
997         return a
998
999
1000 ######################################################################
1001
# Script entry point: a consistency check of the Caterpillar layer,
# followed by a timing loop of the full model (currently unreachable,
# see the note on exit(0) below).
if __name__ == "__main__":
    print("Basic check.")

    # Small Caterpillar layer, with dropout disabled so that repeated
    # calls on the same input can be compared.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    # 21 time steps to process, plus 7 extra steps on each side
    # (presumably mirroring the padding of caterpillar_length that
    # MyGPT.forward adds -- TODO confirm).
    x = torch.randn(1, 21 + 2 * 7, 4)
    # y1, y2: the same bracket of 21 steps computed twice, each time
    # asking for the cache to be initialized.
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    # y3a, y3b: the same 21 steps computed as two consecutive brackets
    # (14 then 7 steps), the second one without initializing the cache.
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    # Largest absolute difference between the two identical runs, then
    # between the one-bracket and the two-bracket computations. Both
    # should presumably be close to zero if the cache is consistent.
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    # NOTE(review): this exit makes everything below unreachable; the
    # timing code is kept but never runs.
    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random batch of 6 sequences of 1024 tokens.
    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # NOTE(review): the argument list could not be checked against the
    # signature of MyGPT.__init__ from here -- confirm it is complete
    # before re-enabling this part.
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    # Timing: three rounds of ten forward passes on the whole batch,
    # printing the duration of each round in seconds.
    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1072
1073 ######################################################################