3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
13 # This implementation is equipped with RNN layers to replace the MHA
20 from torch.nn import functional as F
24 # from blanket import blanket
28 ######################################################################
30 # A BracketedSequence is a BxTx... tensor with a first and a nb time
33 # Modules able to process it expect that they will have to process a
34 # first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with no gaps.
38 # Although it is more general, for a classical prompt-conditioned
39 # auto-regressive process it will be a first bracket starting at 0 and
40 # of arbitrary length for the "prompt", followed by brackets of length
41 # 1 for the successive tokens.
43 # Modules able to process brackets may implement a cache that is
# reset when init_cache is True
class BracketedSequence:
    """A BxTx... tensor paired with a time bracket [first, first+nb).

    Modules that process BracketedSequences operate only on the bracketed
    time slice, which enables cached, chunked auto-regressive evaluation:
    a first bracket starting at t=0, followed by non-overlapping brackets
    moving forward in time.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        # Either all three bracket arguments are given, or none of them
        # (in which case the bracket covers the whole sequence).
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )
        self.x = x
        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        # The bracketed part of the tensor, of shape B x nb x ...
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        # True iff the bracket covers the whole time axis.
        return self.first == 0 and self.nb == self.x.size(1)
65 ######################################################################
class CacheWrapper(nn.Module):
    """Apply a time-wise module only to the bracketed slice, caching the
    output so that successive brackets progressively build the full
    sequence without recomputing past time steps."""

    def __init__(self, *f):
        super().__init__()
        # A single module, or several composed sequentially.
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Allocate a cache covering the full time axis, then fill the
            # current bracket.
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            # Subsequent brackets must be consistent with the cache.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
86 ##############################
class WithResidual(nn.Module):
    """Wrap f with a residual connection: y = x + f(x), bracket-aware."""

    def __init__(self, *f):
        super().__init__()
        # A single module, or several composed sequentially.
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
98 ##############################
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to the bracketed slice.

    [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        if bs.init_cache:
            # t: (T, 1) time positions, j: (1, D) feature indices.
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            # Odd features get a pi/2 phase shift, i.e. a cosine.
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
131 # X is /.../xTxD A is /.../xT Y_init is /.../xD
def pscan_dim(A, X, Y_init, dim=-2):
    """Parallel scan Y[t] = A[t] * Y[t-1] + X[t] along axis `dim`.

    X is /.../xTxD, A is /.../xT, Y_init is /.../xD (or None, meaning
    zeros). Returns a tensor with the same shape as X.
    """
    s = X.size()
    # Collapse the axes before `dim` into a single batch axis of size a.
    a, T = s[:dim].numel(), s[dim]

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    # pscan is a project-local module (not visible in this chunk).
    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    """Reverse-order gradient pass over the scan inputs A and X.

    NOTE(review): this listing appears incomplete — s_A and s_X are read
    before any visible initialization, the computed deltas are never
    written back, and Y_init / eps are unused here. Verify against the
    upstream source before relying on this function.
    """
    with torch.no_grad():
        # Walk the time axis backwards, excluding t=0.
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta
def pscan_shape(A, X, Y_init):
    """Parallel scan Y[t] = A[t] * Y[t-1] + X[t] over the second-to-last
    axis of X, flattening all leading axes into one batch axis.

    X is /.../xTxD, A is /.../xT, Y_init is /.../xD (or None, meaning
    zeros). Returns a tensor with the same shape as X.
    """
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    # pscan is a project-local module (not visible in this chunk).
    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y
def nsum_shape(X, Y_init):
    """Sequential normalized-sum recurrence over the time axis.

    X is /.../xTxD; Y_init is /.../xD or None (meaning zeros). At each
    step the running sum is renormalized so its L2 norm never exceeds 1.
    Returns a tensor with the same shape as X.
    """
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        # Cap the norm at 1 (no-op while the norm is below 1).
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)
193 ##############################
class DumbRec(nn.Module):
    # Memory-lines recurrent attention layer with a fixed (learned) key
    # per line; only values are stored recurrently.
    # NOTE(review): this chunk is an incomplete extraction — the __init__
    # signature, several statement headers (if/else, einsum openings) and
    # closing parentheses are missing. Code fragments are kept verbatim;
    # only comments were added. Verify against the upstream file.

        attention_dropout=0.0,

        # Plain Gaussian parameters scaled by 1/sqrt(fan-in).
        return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        # Accumulators for the auxiliary (inner) loss.
        self.acc_attention = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

            # Recurrent value state per memory line, cached across brackets.
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            # self.rec_k = x_q.new_zeros(
            #     x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

            k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

            warnings.warn("rotating key barrel", RuntimeWarning)
            k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
            t_barrel = torch.arange(t0, t1, device=k_star.device)
            t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
                torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
            k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line forgetting factor for the scan.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

            # self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

            self.rec_v[:, :, t0:t1],

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
324 ##############################
class KVRec(nn.Module):
    # Memory-lines recurrent attention layer storing both keys and values
    # recurrently (compare with DumbRec, which stores values only).
    # NOTE(review): incomplete extraction — the __init__ signature, some
    # statement headers and closing parentheses are missing. Fragments
    # kept verbatim; only comments were added. Verify against upstream.

        attention_dropout=0.0,

        # Plain Gaussian parameters scaled by 1/sqrt(fan-in).
        return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        # Accumulators for the auxiliary (inner) loss.
        self.acc_attention = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

            # Recurrent key/value states per memory line, cached across brackets.
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

            k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

            warnings.warn("rotating key barrel", RuntimeWarning)
            k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
            t_barrel = torch.arange(t0, t1, device=k_star.device)
            t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
                torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
            k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line forgetting factor for the scan.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

            self.rec_v[:, :, t0:t1],

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
460 ##############################
463 # Returns a tensor with an additional index at rank win_dim, that move
464 # along the same dimension as dim, on a domain {0...win_size-1}, and
465 # dim is restricted on a domain reduced by win_size-1 values.
def moving_window(x, dim, win_dim, win_size):
    """Zero-copy sliding-window view of x.

    Inserts an axis of size win_size at rank win_dim that slides along
    axis dim, whose extent shrinks to size[dim] - win_size + 1. No data
    is copied: the result is an as_strided view of x.
    """
    shape, strides = list(x.size()), list(x.stride())
    step = strides[dim]
    shape[dim] -= win_size - 1
    shape.insert(win_dim, win_size)
    strides.insert(win_dim, step)
    return x.as_strided(size=tuple(shape), stride=tuple(strides))
477 ##############################
class Caterpillar(nn.Module):
    # Recurrent attention layer keeping R rows of K/V state updated with
    # period L ("caterpillar"), read out through a sliding window.
    # NOTE(review): incomplete extraction — the __init__ signature and
    # several statements (the G/A/ar/Y computations, some closing
    # parentheses) are missing. Fragments kept verbatim; only comments
    # were added. Verify against the upstream file.

        attention_dropout=0.0,

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, factor=1):
            # Gaussian parameters scaled by factor/sqrt(fan-in).
            return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
        self.init_V_rec = randw(

    # def reset_inner_loss(self):
    #     self.acc_attention = 0

    # def get_inner_loss(self):
    #     warnings.warn("l2 regularization", RuntimeWarning)
    #     return (self.acc_attention / self.acc_nb).pow(2).sum()
    #     return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

            t0 >= L and (t1 - t0) % L == 0
        ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxRxL where N is the sample index, H
        # the head, T the time, R the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L

        # We have an attention score for each of the RxL values

        # softmax can operate only on one dimension, hence the

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
661 ##############################
class QKVAttention(nn.Module):
    # Standard multi-head causal self-attention with K/V caching for
    # bracketed auto-regressive evaluation; `horizon` seems to limit the
    # attention span when set — TODO confirm from the missing mask lines.
    # NOTE(review): incomplete extraction — the __init__ signature, the
    # assert/if headers and several closing parentheses are missing.
    # Fragments kept verbatim; only comments were added.

        attention_dropout=0.0,

        # Plain Gaussian parameters scaled by 1/sqrt(fan-in).
        return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.horizon = horizon
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):

            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

            # Allocate K/V/Y caches covering the full time axis.
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v

            # Attend to all cached positions up to the end of the bracket.
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

                # Causal mask: position t may not attend to s > t.
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]

                if self.horizon is not None:
                    self.cache_attzero = torch.logical_or(
                        torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                        >= torch.arange(x_q.size(1), device=q.device)[

                :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb

        if self.record_attention:

        a = F.dropout(a, self.attention_dropout, self.training)

            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
761 ##############################
class MyGPT(nn.Module):
    # Full model: embedding + positional encoding, a trunk of attention
    # blocks (the operator is selected by `attention_layer`), and a
    # linear readout — all bracket-aware via CacheWrapper.
    # NOTE(review): incomplete extraction — the constructor signature,
    # parts of the per-layer builder and several statement headers are
    # missing. Fragments kept verbatim; only comments were edited.

        caterpillar_height=None,
        attention_layer="caterpillar",

        assert attention_layer in {
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar" or attention_layer == "attcat":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),

            # Build one attention sub-block of the requested flavor.
            if attention_layer == "mha":
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        attention_dropout=dropout,
            elif attention_layer == "dumbrec":
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        attention_dropout=dropout,
            elif attention_layer == "kvrec":
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        attention_dropout=dropout,
            elif attention_layer == "caterpillar":
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        caterpillar_length=self.caterpillar_length,
                        caterpillar_height=self.caterpillar_height,
                        attention_dropout=dropout,
            elif attention_layer == "attcat":
                # "attcat" stacks an attention block and a caterpillar block.
                return nn.Sequential(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        horizon=self.caterpillar_length,
                        attention_dropout=dropout,
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                        dim_v=dim_model // nb_heads,
                        caterpillar_length=self.caterpillar_length,
                        caterpillar_height=self.caterpillar_height,
                        attention_dropout=dropout,
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        with torch.no_grad():
            # Custom initialization of embeddings and layer norms.
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):

        self.reset_inner_loss()

    def forward(self, bs):
        # Shift the input one step right so token t predicts token t+1.
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:

            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,

        bs = self.embedding(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            # Negative padding trims the extra caterpillar positions back off.
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept

    def masked_inplace_autoregression(
        forbidden_tokens=None,
        deterministic_synthesis=False,
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
                BracketedSequence(input, 0, to_generate.min(), True)
            ) # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Only overwrite positions selected by the mask.
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)

    def reset_inner_loss(self):
        # Forward the reset to every sub-module implementing it.
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):

    def get_inner_loss(self):
        # Sum the auxiliary losses of every sub-module implementing one.
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        for m in self.modules():
            if isinstance(m, QKVAttention):
1021 ######################################################################
if __name__ == "__main__":
    # Self-test / profiling script: compares gradient magnitudes across
    # layer types, plots them, checks caterpillar cache consistency, and
    # times a forward pass of the full model.
    # NOTE(review): incomplete extraction — several statements (figure
    # creation, the qkv constructor, loop headers, sg initialization,
    # the plot call, the MyGPT constructor) are missing. Fragments kept
    # verbatim; only comments were added.

    import matplotlib.pyplot as plt
    import matplotlib.collections as mc

    args = argparse.Namespace(
        gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dim_model, dim_keys, nb_heads = 512, 64, 1

    caterpillar = Caterpillar(
        dim_model=dim_model,
        dim_v=dim_model // nb_heads,
        caterpillar_length=16,
        caterpillar_height=32,
        attention_dropout=dropout,

        dim_model=dim_model,
        dim_v=dim_model // nb_heads,
        attention_dropout=dropout,

    linear = CacheWrapper(nn.Linear(512, 512)).to(device)

    x = torch.randn(1, 256, dim_model)

    ######################################################################

    fig.set_figheight(6)

    ax = fig.add_subplot(1, 1, 1)

    # ax.set_xlim(-1.5, 1.5)
    # ax.set_ylim(-1.5, 1.5)
    # ax.spines.right.set_visible(False)
    # ax.spines.top.set_visible(False)

    # t = np.arange(dt, 20.0, dt)
    # ax.semilogx(t, np.exp(-t / 5.0))

    ax.set_yscale("log")

    ######################################################################

    for label, model, thickness in [
        ("nn.Linear", linear, 0.2),
        ("mygpy.QKVAttention", qkv, 1),
        ("mygpt.Caterpillar", caterpillar, 2),

        y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x

        for n, p in [("input", x)] + list(model.named_parameters()):
            print(f"Processing {model}.{n}")
            for t in range(y.size(1)):
                # Sample 8 output dims and accumulate their gradients w.r.t. p.
                for d in torch.randperm(y.size(2))[:8]:
                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
                assert not sg.isinf().any()
                assert not sg.isnan().any()
                data.append([t, sg.sum().item()])

            data = torch.tensor(data)
            # cx, cy = data[:, 0], data[:, 1]
            cy = data[:, 1].sort().values
            cx = torch.linspace(0, 1, cy.size(0))
                cx, cy, label=label + "." + n, linewidth=thickness
            ) # , color='gray', label='Input')

    # ax.legend(frameon=False, loc="top right")

    # Put a legend to the right of the current axis
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    filename = "plot.pdf"
    print(f"saving (unknown)")
    fig.savefig(filename, bbox_inches="tight")

    # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
    # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)

    ######################################################################

        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    # Full-bracket vs split-bracket evaluations must agree (cache check).
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

        vocabulary_size=vocabulary_size,

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    start_time = time.perf_counter()
    model(BracketedSequence(x))
    duration = time.perf_counter() - start_time

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1200 ######################################################################