# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid a O(N^3) cost
# for auto-regression.

# This implementation is equipped with RNN layers to replace the MHA

import math, warnings

import torch

from torch import nn
from torch.nn import functional as F

import pscan

# from blanket import blanket

######################################################################

# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to process.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.

# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.

# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )
        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)
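
# For instance, a prompt-conditioned auto-regressive generation over a
# sequence of length 8 with a prompt of length 5 would process, in
# order,
#
#   BracketedSequence(x, first=0, nb=5, init_cache=True)
#   BracketedSequence(x, first=5, nb=1, init_cache=False)
#   BracketedSequence(x, first=6, nb=1, init_cache=False)
#   BracketedSequence(x, first=7, nb=1, init_cache=False)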


######################################################################

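# CacheWrapper wraps a module that processes every time step
# independently (e.g. nn.Linear, nn.LayerNorm, nn.Dropout, or an
# nn.Sequential of such modules). It applies the wrapped module to the
# current bracket only, and caches the result over the full sequence,
# so that later brackets do not recompute earlier time steps.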


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class NaNChecker(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, bs):
        x = bs.x if type(bs) is BracketedSequence else bs
        assert not x.isnan().any(), f"{self.name} detected NaN"
        assert not x.isinf().any(), f"{self.name} detected Inf"
        return bs


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
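    #
    # The forward below implements both cases with a single sin(): with
    # k = j % 2, the extra phase of k * pi / 2 turns the sine into a
    # cosine on the odd coordinates.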

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


######################################################################


# X is /.../xTxD   A is /.../xT   Y_init is /.../xD
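#
# pscan.pscan(A, X, Y_init) is assumed to compute, in parallel along
# the time axis, the first-order recurrence
#
#   Y[:, 0] = A[:, 0] * Y_init + X[:, 0]
#   Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
#
# The helpers below only reshape their arguments to the BxTx... layout
# that pscan expects.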


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    with torch.no_grad():
        s_A, s_X = 0, 0
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta


def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
            self.w_qw.size(1)
        )

        aw = aw.softmax(dim=2)  # nhlt

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        if t0 == 0:
            V0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        # warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, on a domain {0...win_size-1}, and
# dim is restricted to a domain reduced by win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
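

# For instance, with x of size (2, 7, 3), moving_window(x, dim=1,
# win_dim=2, win_size=3) returns a view of size (2, 5, 3, 3) such that
#
#   result[n, t, k, :] == x[n, t + k, :]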


##############################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, factor=1):
            return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    # def reset_inner_loss(self):
    #     self.acc_attention = 0
    #     self.acc_nb = 0

    # def get_inner_loss(self):
    #     warnings.warn("l2 regularization", RuntimeWarning)
    #     return (self.acc_attention / self.acc_nb).pow(2).sum()
    #     return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Local names for the dimensions, to make the source a bit clearer

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)
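
        # That is, the total gate applied by all the heads to a given
        # row and time step is renormalized only when it exceeds 1, so
        # a row is never overwritten by more than a total weight of 1.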

        ######################################################################

        A = 1 - G.sum(dim=1)

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.
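        #
        # Concretely, with T' = t1 - t0 time steps in the bracket, A of
        # shape (N, R, T') becomes (N, R, T'/L, L), and the scan over
        # dim=2 advances each of the L columns independently, with a
        # stride of L time steps.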

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxRxL where N is the sample index, H
        # the head, T the time, R the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening and reshaping below

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtrl,nrtld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)


##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        horizon=None,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.horizon = horizon
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )

                if self.horizon is not None:
                    self.cache_attzero = torch.logical_or(
                        self.cache_attzero,
                        torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                        >= torch.arange(x_q.size(1), device=q.device)[
                            None, None, None, :
                        ]
                        + self.horizon,
                    )

            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="caterpillar",
    ):
        super().__init__()

        self.vocabulary_size = vocabulary_size

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
            "attcat",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar" or attention_layer == "attcat":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        def attlayer():
            if attention_layer == "mha":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    QKVAttention(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                    ),
                )
            elif attention_layer == "dumbrec":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    DumbRec(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        nb_lines=nb_lines,
                        attention_dropout=dropout,
                    ),
                )
            elif attention_layer == "kvrec":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    KVRec(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        nb_lines=nb_lines,
                        attention_dropout=dropout,
                    ),
                )
            elif attention_layer == "caterpillar":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    Caterpillar(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        caterpillar_length=self.caterpillar_length,
                        caterpillar_height=self.caterpillar_height,
                        attention_dropout=dropout,
                    ),
                )
            elif attention_layer == "attcat":
                return nn.Sequential(
                    WithResidual(
                        CacheWrapper(nn.LayerNorm((dim_model,))),
                        QKVAttention(
                            dim_model=dim_model,
                            dim_qk=dim_keys,
                            dim_v=dim_model // nb_heads,
                            nb_heads=nb_heads,
                            causal=causal,
                            horizon=self.caterpillar_length,
                            attention_dropout=dropout,
                        ),
                    ),
                    WithResidual(
                        CacheWrapper(nn.LayerNorm((dim_model,))),
                        Caterpillar(
                            dim_model=dim_model,
                            dim_qk=dim_keys,
                            dim_v=dim_model // nb_heads,
                            nb_heads=nb_heads,
                            caterpillar_length=self.caterpillar_length,
                            caterpillar_height=self.caterpillar_height,
                            attention_dropout=dropout,
                        ),
                    ),
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        trunk_blocks = []

        for b in range(nb_blocks):
            trunk_blocks += [
                attlayer(),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
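
        # The pad above drops the last token and shifts the input one
        # step to the right, so that the prediction at position t is
        # conditioned only on the tokens strictly before t.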

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)
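
    # Typical use (a sketch; the mask construction below is an assumed
    # example, not taken from this file): fill in place the positions
    # of `input` flagged by `ar_mask`, conditioning on the unmasked
    # prefix,
    #
    #   ar_mask = (torch.arange(input.size(1))[None, :] >= prompt_len)
    #   ar_mask = ar_mask.long().expand_as(input)
    #   model.masked_inplace_autoregression(input, ar_mask)
    #
    # where prompt_len is a placeholder for the length of the prompt.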

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


######################################################################

if __name__ == "__main__":
    import argparse, time
    import matplotlib.pyplot as plt
    import matplotlib.collections as mc

    args = argparse.Namespace(
        gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dim_model, dim_keys, nb_heads = 512, 64, 1
    dropout = 0.1  # arbitrary value, for this test only

    caterpillar = Caterpillar(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        caterpillar_length=16,
        caterpillar_height=32,
        attention_dropout=dropout,
    ).to(device)

    qkv = QKVAttention(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        causal=True,
        attention_dropout=dropout,
    ).to(device)

    linear = CacheWrapper(nn.Linear(512, 512)).to(device)

    x = torch.randn(1, 256, dim_model)
    x = x.to(device)
    x.requires_grad_()

    ######################################################################

    fig = plt.figure()
    fig.set_figheight(6)

    ax = fig.add_subplot(1, 1, 1)

    # ax.set_xlim(-1.5, 1.5)
    # ax.set_ylim(-1.5, 1.5)

    # ax.spines.right.set_visible(False)
    # ax.spines.top.set_visible(False)

    # t = np.arange(dt, 20.0, dt)
    # ax.semilogx(t, np.exp(-t / 5.0))

    ax.set_yscale("log")

    ######################################################################

    for label, model, thickness in [
        ("nn.Linear", linear, 0.2),
        ("mygpt.QKVAttention", qkv, 1),
        ("mygpt.Caterpillar", caterpillar, 2),
    ]:
        y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x

        for n, p in [("input", x)] + list(model.named_parameters()):
            print(f"Processing {model}.{n}")
            data = []
            for t in range(y.size(1)):
                sg = 0
                for d in torch.randperm(y.size(2))[:8]:
                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
                assert not sg.isinf().any()
                assert not sg.isnan().any()
                data.append([t, sg.sum().item()])

            data = torch.tensor(data)
            # cx, cy = data[:, 0], data[:, 1]
            cy = data[:, 1].sort().values
            cx = torch.linspace(0, 1, cy.size(0))
            ax.plot(
                cx, cy, label=label + "." + n, linewidth=thickness
            )  # , color='gray', label='Input')

    # ax.legend(frameon=False, loc="top right")

    # Put a legend to the right of the current axis
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    filename = "plot.pdf"
    print(f"saving {filename}")
    fig.savefig(filename, bbox_inches="tight")

    # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
    #     plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)

    ######################################################################

    # A quick sanity check of the Caterpillar caching; dim_qk and dim_v
    # below are arbitrary values chosen for this test.

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
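
    # Both differences should be close to zero: the first checks that
    # two identical full-bracket passes agree, the second that the
    # split-bracket (cached) evaluation matches the single-bracket one.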

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # The hyper-parameters below are placeholders chosen for this speed
    # test only.

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_blocks=4,
        nb_lines=128,
        caterpillar_height=16,
        causal=True,
    ).to(device)

    x = x.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    #     with record_function("model_inference"):

    start_time = time.perf_counter()
    model(BracketedSequence(x))
    duration = time.perf_counter() - start_time

    print(f"duration {duration}")

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    #     z = model(BracketedSequence(x, s, 1))
    #     y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################