# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.

import math, time, warnings

import torch

from torch import nn
from torch.nn import functional as F

import pscan

######################################################################


# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to process.
#
# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when the input bracket starts at t=0


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)
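

# A usage sketch, added for illustration and not part of the original
# file: a prompt bracket followed by length-1 brackets tiles the time
# axis with no overlap and no hole. All names are local to the demo.


def _demo_bracketed_sequence():
    x = torch.randn(2, 8, 4)
    full = BracketedSequence(x)  # defaults to first=0, nb=8, init_cache=True
    assert full.complete() and full.slice().size(1) == 8
    prompt = BracketedSequence(x, first=0, nb=5, init_cache=True)
    steps = [BracketedSequence(x, first=t, nb=1, init_cache=False) for t in range(5, 8)]
    covered = [(b.first, b.first + b.nb) for b in [prompt] + steps]
    assert covered == [(0, 5), (5, 6), (6, 7), (7, 8)]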


######################################################################


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
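

# A sanity sketch, added for illustration and not part of the original
# file: the pi/2 phase shift on odd dimensions turns the single sin()
# above into the usual sin/cos pair of the Transformer encoding.


def _demo_positional_encoding():
    len_max, T, D = 1e5, 16, 8
    t = torch.arange(T, dtype=torch.float64)[:, None]
    j = torch.arange(D, dtype=torch.float64)[None, :]
    k = j % 2
    pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    # even dims are sin, odd dims are cos at the same frequency
    assert torch.allclose(pe[:, 0], torch.sin(t[:, 0]))
    assert torch.allclose(pe[:, 1], torch.cos(t[:, 0]))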


##############################


# X is /.../xTxD  A is /.../xT  Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y
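

# A reference check, added for illustration and not part of the
# original file, assuming pscan.pscan computes the linear recurrence
# Y[t] = A[t] * Y[t-1] + X[t]: the naive loop below is the ground
# truth pscan_dim must match.


def _demo_pscan_dim():
    N, T, D = 2, 5, 3
    A, X = torch.rand(N, T), torch.randn(N, T, D)
    Y = torch.empty(N, T, D)
    y = torch.zeros(N, D)  # Y_init=None means a zero initial state
    for t in range(T):
        y = A[:, t, None] * y + X[:, t]
        Y[:, t] = y
    assert torch.allclose(pscan_dim(A, X, None), Y, atol=1e-5)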


def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            #     x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
            # self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
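

# A sketch, added for illustration and not part of the original file:
# the indexing behind the "rotating key barrel" above, on integer keys
# so the rotation is visible. Line l serves key (l + t) % nb_lines at
# time t, so every memory line sees every key over time.


def _demo_key_barrel():
    L, T = 4, 6
    k_star = torch.arange(L)[:, None].expand(L, T)
    t_barrel = torch.arange(T)[None, :].expand(L, T)
    l_barrel = (torch.arange(L)[:, None] + t_barrel) % L
    rotated = k_star[l_barrel, t_barrel]
    assert (rotated == (torch.arange(L)[:, None] + torch.arange(T)[None, :]) % L).all()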


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, on a domain {0...win_size-1}, and
# dim is restricted to a domain reduced by win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
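

# A demo, added for illustration and not part of the original file:
# windows over a 1D ramp, showing that entry [t, l] of the result is
# x[t + l], with no copy thanks to as_strided.


def _demo_moving_window():
    x = torch.arange(6)[None, :]  # shape (1, 6)
    w = moving_window(x, dim=1, win_dim=2, win_size=3)
    assert w.size() == (1, 4, 3)
    assert (w[0] == torch.tensor([[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]])).all()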


##############################


# This is one order of magnitude more complicated than I expected, not
# elegant, slow, hopefully not buggy


def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
    # starting flash backs
    fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long()
    fb_start[:, :, -CL:] = 0
    fb_start[:, :, :CL] = 0

    # Remove series longer than CL
    fb_body = fb_start.clone()
    fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)]
    fb_body = fb_body.cumsum(dim=2)
    fb_start = fb_start * (fb_body == 1)

    # Set an origin source time (starting time of the chunk to copy
    # here). We set it as the current time minus a multiple of CL to be
    # consistent with the "rolling" caterpillar
    t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :]
    src_time = fb_start * (
        t
        - CL
        * (
            1
            + (
                torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1)
            ).long()
        )
    )
    src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
    src_time = src_time.cumsum(dim=2)

    src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device)
    src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL]
    src_head = src_head.cumsum(dim=2)

    src_delta = fb_start.clone()
    src_delta[:, :, CL:] -= fb_start[:, :, :-CL]
    src_delta = src_delta.cumsum(dim=2)
    src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL]
    src_time += src_delta.cumsum(dim=2) - 1

    return src_time, src_head


def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
    N, H, CH = V.size(0), V.size(1), rec_V.size(1)

    fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device)

    fbt_V = fbt[:, :, :, None]
    fbh_V = fbh[:, :, :, None]
    t = fbt_V.clamp(min=0)
    n = torch.arange(V.size(0), device=V.device)[:, None, None, None]
    d = torch.arange(V.size(3), device=V.device)[None, None, None, :]
    q = V[:, :, t0:t1][n, fbh_V, t, d]
    rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0)

    fbt_K = fbt[:, :, :, None]
    fbh_K = fbh[:, :, :, None]
    t = fbt_K.clamp(min=0)
    n = torch.arange(K.size(0), device=K.device)[:, None, None, None]
    d = torch.arange(K.size(3), device=K.device)[None, None, None, :]
    q = K[:, :, t0:t1][n, fbh_K, t, d]
    rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0)


######################################################################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        warnings.warn("flash back", RuntimeWarning)
        self.proba_flashback = 0.1

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions, named to make the source a bit clearer

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # That was a bad idea
        # G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: since the stack at time t is computed
        # by updating that at time t-CL, the parallel scan operates
        # with a period of CL. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        if self.training and self.proba_flashback:
            # insert_flash_back(
            #     self.rec_V, V, self.rec_K, K, t0, t1, CL,
            #     proba=self.proba_flashback / CL,
            # )

            n = torch.arange(N, device=X.device)[:, None, None, None]
            t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
            dv = torch.arange(DV, device=X.device)[None, None, None, :]
            dk = torch.arange(DK, device=X.device)[None, None, None, :]

            u = (
                torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
            ) * CL

            src_time = t - u - t0
            src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)

            mask_V = (
                torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
            ).long()
            self.rec_V[:, :, t0:t1] = (
                mask_V * V[n, src_head, src_time, dv]
                + (1 - mask_V) * self.rec_V[:, :, t0:t1]
            )

            mask_K = (
                torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
            ).long()
            self.rec_K[:, :, t0:t1] = (
                mask_K * K[n, src_head, src_time, dk]
                + (1 - mask_K) * self.rec_K[:, :, t0:t1]
            )

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
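

# A sketch, added for illustration and not part of the original file:
# the period-CL trick used above, checked against a naive loop. A
# recurrence whose state at t depends on the state at t-CL splits into
# CL independent recurrences once the time axis is unflattened into
# (T//CL, CL).


def _demo_period_cl_scan():
    N, T, CL = 2, 12, 3
    A, X = torch.rand(N, T), torch.randn(N, T)

    # ground truth: Y[t] = A[t] * Y[t - CL] + X[t], first CL steps are X
    Y = X.clone()
    for t in range(CL, T):
        Y[:, t] = A[:, t] * Y[:, t - CL] + X[:, t]

    # same recurrence, scanning over the first of the two unflattened axes
    A_, X_ = A.unflatten(1, (-1, CL)), X.unflatten(1, (-1, CL))
    Z = X_.clone()
    for k in range(1, T // CL):
        Z[:, k] = A_[:, k] * Z[:, k - 1] + X_[:, k]

    assert torch.allclose(Y, Z.flatten(1))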


##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
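

# A check, added for illustration and not part of the original file: a
# causal QKVAttention should give the same output whether the sequence
# is processed as one bracket or as a prompt followed by length-1
# brackets, since only cached keys and values are attended to.


def _demo_kv_cache():
    torch.manual_seed(0)
    att = QKVAttention(dim_model=8, dim_qk=4, dim_v=4, nb_heads=2, causal=True)
    x = torch.randn(1, 6, 8)
    y_full = att(BracketedSequence(x)).x.clone()
    att(BracketedSequence(x, 0, 3, True))  # prompt bracket, resets the cache
    for s in range(3, 6):
        y_inc = att(BracketedSequence(x, s, 1, False)).x
    assert (y_full - y_inc).abs().max() < 1e-5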


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)
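
    # Usage sketch, added for illustration and not part of the original
    # file, assuming `model` is a trained MyGPT and the first 4 tokens
    # of each row of `input` hold the prompt:
    #
    #   ar_mask = (torch.arange(input.size(1)) >= 4).long()[None, :].expand_as(input)
    #   model.masked_inplace_autoregression(input, ar_mask)
    #
    # Rows of `input` are completed in place from position 4 onward.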

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


######################################################################

if __name__ == "__main__":
    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    #     with record_function("model_inference"):

    start_time = time.perf_counter()
    model(BracketedSequence(x))
    duration = time.perf_counter() - start_time

    print(duration)

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    #     z = model(BracketedSequence(x, s, 1))
    #     y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################