3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
13 # This implementation is equipped with RNN layers to replace the MHA
20 from torch.nn import functional as F
24 # from blanket import blanket
28 ######################################################################
30 # A BracketedSequence is a BxTx... tensor with a first and a nb time
33 # Modules able to process it expect that they will have to process a
34 # first bracket starting at t=0, followed by a succession of brackets
35 # that move forward in time, do not overlap, and cover the axis T with
38 # Although it is more general, for a classical prompt-conditioned
39 # auto-regressive process it will be a first bracket starting at 0 and
40 # of arbitrary length for the "prompt", followed by brackets of length
41 # 1 for the successive tokens.
43 # Modules able to process brackets may implement a cache that is
44 # reset when init_cache is True.
47 class BracketedSequence:
48 def __init__(self, x, first=None, nb=None, init_cache=None):
50 assert (first is None and nb is None and init_cache is None) or (
51 first is not None and nb is not None and init_cache is not None
54 self.first = 0 if first is None else first
55 self.nb = x.size(1) if nb is None else nb
56 self.init_cache = True if init_cache is None else init_cache
59 return self.x[:, self.first : self.first + self.nb]
62 return self.first == 0 and self.nb == self.x.size(1)
65 ######################################################################
68 class CacheWrapper(nn.Module):
69 def __init__(self, *f):
71 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
73 def forward(self, bs):
75 y = self.f(bs.slice())
76 self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
77 self.cache_y[:, bs.first : bs.first + bs.nb] = y
79 assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
80 assert bs.first + bs.nb <= self.cache_y.size(1)
81 self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
83 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
86 ##############################
89 class WithResidual(nn.Module):
90 def __init__(self, *f):
92 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
    """Residual connection around the wrapped module ``self.f``.

    ``self.f`` receives the whole BracketedSequence; its output tensor is
    added element-wise to the input tensor ``bs.x`` and the sum is wrapped
    in a new BracketedSequence carrying the same bracket (first, nb) and
    init_cache flag.
    """
    # Read the skip tensor before calling f, exactly as the one-liner did.
    skip = bs.x
    branch = self.f(bs).x
    return BracketedSequence(skip + branch, bs.first, bs.nb, bs.init_cache)
98 ##############################
101 class AddPositionalEncoding(nn.Module):
102 def __init__(self, len_max):
104 self.len_max = len_max
106 # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
108 def forward(self, bs):
110 t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
113 j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
118 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
120 self.cache_y = bs.x.new(bs.x.size())
122 self.cache_y[:, bs.first : bs.first + bs.nb] = (
123 bs.slice() + self.pe[bs.first : bs.first + bs.nb]
126 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
131 # X is /.../xTxD, A is /.../xT, and Y_init is /.../xD
134 def pscan_dim(A, X, Y_init, dim=-2):
136 a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
138 A = A.reshape(a, T, *s[dim + 1 : -1])
139 X = X.reshape(a, T, *s[dim + 1 : -1], -1)
142 Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
144 Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
146 Y = pscan.pscan(A, X, Y_init).reshape(s)
151 def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
152 with torch.no_grad():
154 for t in range(X.size(dim) - 1, 0, -1):
155 delta = (grad_Y[t] - s_A) / A[t].grad
156 s_A += A[t].grad * delta
158 delta = (grad_Y[t] - s_X) / X[t].grad
159 s_X += X[t].grad * delta
163 def pscan_shape(A, X, Y_init):
165 A = A.reshape(-1, s[-2])
166 X = X.reshape(-1, s[-2], s[-1])
169 Y_init = X.new_zeros(X.size(0), s[-1])
171 Y_init = Y_init.reshape(-1, s[-1])
173 Y = pscan.pscan(A, X, Y_init).reshape(s)
178 def nsum_shape(X, Y_init):
180 X = X.reshape(-1, s[-2], s[-1]) # ntd
182 Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
185 for k in range(X.size(1)):
187 Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
190 return torch.cat(result, dim=1).reshape(s)
193 ##############################
196 class DumbRec(nn.Module):
204 attention_dropout=0.0,
212 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
214 self.nb_lines = nb_lines
215 self.attention_dropout = attention_dropout
217 self.k_star = randw(nb_lines, dim_qk)
219 self.w_qw = randw(nb_heads, dim_qk, dim_model)
220 self.w_qr = randw(nb_heads, dim_qk, dim_model)
221 # self.w_k = randw(nb_heads, dim_qk, dim_model)
222 self.w_v = randw(nb_heads, dim_v, dim_model)
223 self.w_o = randw(dim_v * nb_heads, dim_model)
225 def reset_inner_loss(self):
226 self.acc_attention = 0
def get_inner_loss(self):
    """Return the squared-L2 penalty on the mean write attention per line.

    ``forward`` adds to ``self.acc_attention`` the write attention (softmax
    over memory lines) summed over batch, heads and time, one entry per
    memory line, and adds to ``self.acc_nb`` the number of summed terms.
    Their ratio is the mean attention received by each line; this returns
    the sum of its squares. A RuntimeWarning signals that the regularizer
    is active.
    """
    warnings.warn("l2 regularization", RuntimeWarning)
    mean_per_line = self.acc_attention / self.acc_nb
    # Disabled alternative: return torch.tensor([0], device=self.w_qw.device)
    return mean_per_line.pow(2).sum()
234 def forward(self, bs):
235 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
238 self.rec_v = x_q.new_zeros(
239 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
241 # self.rec_k = x_q.new_zeros(
242 # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
244 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
246 ######################################################################
249 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
251 warnings.warn("rotating key barrel", RuntimeWarning)
252 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
253 t_barrel = torch.arange(t0, t1, device=k_star.device)
254 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
256 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
258 k_star = k_star[l_barrel, t_barrel]
260 ######################################################################
261 # Compute the recurrent state
263 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
265 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
266 # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
272 ) / math.sqrt(self.w_qw.size(1))
274 aw = aw.softmax(dim=2) # nhlt
277 self.acc_attention += aw.sum(dim=(0, 1, 3))
278 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
280 aw = F.dropout(aw, self.attention_dropout, self.training)
282 A = 1 - aw.sum(dim=1) # nlt
284 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
285 # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
291 V0 = self.rec_v[:, :, t0 - 1]
292 # K0 = self.rec_k[:, :, t0 - 1]
294 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
295 # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
297 ######################################################################
298 # compute the readout
300 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
305 # self.rec_k[:, :, t0:t1],
307 ) / math.sqrt(self.w_qr.size(1))
309 ar = ar.softmax(dim=2) # nhlt
311 ar = F.dropout(ar, self.attention_dropout, self.training)
316 self.rec_v[:, :, t0:t1],
319 self.cache_y[:, t0:t1] = y @ self.w_o
321 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
324 ##############################
327 class KVRec(nn.Module):
335 attention_dropout=0.0,
343 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
345 self.nb_lines = nb_lines
346 self.attention_dropout = attention_dropout
348 self.k_star = randw(nb_lines, dim_qk)
350 self.w_qw = randw(nb_heads, dim_qk, dim_model)
351 self.w_qr = randw(nb_heads, dim_qk, dim_model)
352 self.w_k = randw(nb_heads, dim_qk, dim_model)
353 self.w_v = randw(nb_heads, dim_v, dim_model)
354 self.w_o = randw(dim_v * nb_heads, dim_model)
356 def reset_inner_loss(self):
357 self.acc_attention = 0
def get_inner_loss(self):
    """Return the L2 penalty on the average write attention of each line.

    ``forward`` accumulates in ``self.acc_attention`` the write attention
    (softmax over memory lines) summed over batch, heads and time, one
    entry per line, and in ``self.acc_nb`` the number of summed terms. The
    ratio is each line's mean attention; the sum of its squares is the
    penalty. The comment in ``forward`` states the intent: all memory lines
    should be used similarly. A RuntimeWarning flags that the regularizer
    is on.
    """
    warnings.warn("l2 regularization", RuntimeWarning)
    per_line_mean = self.acc_attention / self.acc_nb
    penalty = per_line_mean.pow(2).sum()
    return penalty
    # Disabled alternatives kept for experimentation:
    #   - no regularization: torch.tensor([0], device=self.w_qw.device)
    #   - "side regularization", penalizing lines whose mean attention is
    #     below 0.5 / self.nb_lines:
    #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
370 def forward(self, bs):
371 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
374 self.rec_v = x_q.new_zeros(
375 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
377 self.rec_k = x_q.new_zeros(
378 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
380 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
382 ######################################################################
385 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
387 warnings.warn("rotating key barrel", RuntimeWarning)
388 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
389 t_barrel = torch.arange(t0, t1, device=k_star.device)
390 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
392 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
394 k_star = k_star[l_barrel, t_barrel]
396 ######################################################################
397 # Compute the recurrent state
399 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
401 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
402 k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
408 ) / math.sqrt(self.w_qw.size(1))
410 aw = aw.softmax(dim=2) # nhlt
413 # We want all the memory lines to be used similarly
414 self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum accross NxHx_xT
415 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
417 aw = F.dropout(aw, self.attention_dropout, self.training)
419 A = 1 - aw.sum(dim=1) # nlt
421 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
422 K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
428 V0 = self.rec_v[:, :, t0 - 1]
429 K0 = self.rec_k[:, :, t0 - 1]
431 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
432 self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
434 ######################################################################
435 # compute the readout
437 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
442 self.rec_k[:, :, t0:t1],
443 ) / math.sqrt(self.w_qr.size(1))
445 ar = ar.softmax(dim=2) # nhlt
447 ar = F.dropout(ar, self.attention_dropout, self.training)
452 self.rec_v[:, :, t0:t1],
455 self.cache_y[:, t0:t1] = y @ self.w_o
457 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
460 ##############################
463 # Returns a strided view with an additional index at rank win_dim that moves
464 # along the same dimension as dim, over the domain {0...win_size-1},
465 # while dim is restricted to a domain reduced by win_size-1 values.
def moving_window(x, dim, win_dim, win_size):
    """Return a sliding-window view of ``x`` along ``dim``, without copying.

    The result has one more axis than ``x``. A new axis of size ``win_size``
    is inserted at rank ``win_dim``; it steps with the same stride as
    ``dim``. The axis ``dim`` itself shrinks to ``x.size(dim) - win_size + 1``
    positions. For example, with ``x`` of shape (A, B), ``dim=1`` and
    ``win_dim=2``: ``out[a, i, k] == x[a, i + k]``.

    Args:
        x: input tensor; any strides or storage offset (e.g. a slice) work.
        dim: axis to slide along. Negative values count from the end of
            ``x``'s shape.
        win_dim: position at which the window axis is inserted into ``x``'s
            shape tuple (tuple-insert semantics, unchanged).
        win_size: window length.

    Returns:
        A view built with ``Tensor.as_strided``. It shares storage with
        ``x``, and overlapping windows alias the same elements.
    """
    if dim < 0:
        # Without this, dim == -1 makes size[dim + 1:] the *whole* shape,
        # giving mismatched size/stride lengths and an as_strided error.
        dim += x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    # The window index moves with the same stride as ``dim``.
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    # storage_offset defaults to x's own offset, so sliced inputs work.
    return x.as_strided(size=size, stride=stride)
477 ##############################
480 class Caterpillar(nn.Module):
489 attention_dropout=0.0,
496 warnings.warn("Caterpillar", RuntimeWarning)
498 def randw(*d, factor=1):
499 return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))
501 self.caterpillar_length = caterpillar_length
502 self.caterpillar_height = caterpillar_height
503 self.attention_dropout = attention_dropout
505 self.gate_dropout_proba = args.gate_dropout_proba
506 self.gate_dropout_sync = args.gate_dropout_sync
507 self.gate_dropout_replace = args.gate_dropout_replace
509 ######################################################################
511 self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
512 self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
514 self.w_K = randw(nb_heads, dim_qk, dim_model)
515 self.w_V = randw(nb_heads, dim_v, dim_model)
516 self.w_Q = randw(nb_heads, dim_qk, dim_model)
517 self.w_O = randw(dim_v * nb_heads, dim_model)
519 self.init_K_rec = randw(
524 self.init_V_rec = randw(
530 # def reset_inner_loss(self):
531 # self.acc_attention = 0
534 # def get_inner_loss(self):
535 # warnings.warn("l2 regularization", RuntimeWarning)
536 # return (self.acc_attention / self.acc_nb).pow(2).sum()
537 # return torch.tensor([0], device=self.w_Q.device)
539 def forward(self, bs):
540 # Dimensions to make the source a bit clearer, that's needed
542 X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
547 DV = self.w_V.size(1)
548 DK = self.w_K.size(1)
549 DM = self.w_O.size(1)
550 R = self.caterpillar_height
551 L = self.caterpillar_length
554 t0 >= L and (t1 - t0) % L == 0
555 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
557 # We cache values to deal efficiently with auto-regression
560 self.rec_V = X.new_zeros(N, R, T, DV)
561 self.rec_K = X.new_zeros(N, R, T, DK)
562 # We start the recurrent sequences with optimizable
563 # initial values. No idea if it helps.
564 self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
565 self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
567 self.cache_Y = X.new_zeros(N, T, DM)
569 V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
570 K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
572 # V, K = blanket(V), blanket(K)
574 ######################################################################
575 # Compute the recurrent state
577 # This is the Gating sequence that modulates the storing of
578 # the new key and value in the R pairs of the current
579 # stack. There are R independent gating values, which means
580 # that the current K/V may be stored in multiple pairs of the
581 # recurrent state, or not at all.
584 torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
587 # Clip the gating to avoid values greater than 1 when several
588 # heads hit the same row
590 G = G / G.sum(1, keepdim=True).clamp(min=1)
592 # G_star = (1 - G).log().sum(1, keepdim=True).exp()
594 ######################################################################
596 def recurrence(G, V, K):
597 # We prepare the arguments for the parallel scan
601 gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
602 gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
604 # We start from cached values, which matters in inference
606 init_rec_V = self.rec_V[:, :, t0 - L : t0]
607 init_rec_K = self.rec_K[:, :, t0 - L : t0]
609 # Here there is a trick: Since the stack at position t is
610 # computed by updating that at position t-L, the parallel
611 # scan operates with a period of L. To do so we split the
612 # sequence indexing in two axes, the second of size L, and
613 # run the parallel scan using the first as the sequence index.
615 A = A.unflatten(2, (-1, L))
616 gated_V = gated_V.unflatten(2, (-1, L))
617 gated_K = gated_K.unflatten(2, (-1, L))
619 next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
620 next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
622 return next_V, next_K
624 #################################################################
626 next_V, next_K = recurrence(G, V, K)
628 if self.training and self.gate_dropout_proba > 0.0:
629 # G is NxHxRxT where r is the caterpillar's row.
631 warnings.warn("gate dropout", RuntimeWarning)
633 if self.gate_dropout_sync:
634 shape_kill = (N, 1, 1)
636 shape_kill = (N, H, R)
638 # Pick a point in each of the NxHxR timeline and set this
639 # entry and the following to 1
641 torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
645 # Keep these mask for only some of the NxHxR
647 torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
650 # The coefficient to keep are the complementary
653 masked_next_V, masked_next_K = recurrence(G * mask, V, K)
655 if self.gate_dropout_replace:
656 next_V = next_V.detach()
657 next_K = next_K.detach()
659 warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)
661 next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
662 1 - self.gate_dropout_proba
664 next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
665 1 - self.gate_dropout_proba
668 self.rec_V[:, :, t0:t1] = next_V
669 self.rec_K[:, :, t0:t1] = next_K
671 ######################################################################
672 # compute the readout
674 Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
678 # We build tensors NxHxTxRxL where N is the sample index, H
679 # the head, T the time, R the row in the caterpillar, and L
680 # the column in the caterpillar
682 windowed_V = moving_window(
683 self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
686 windowed_K = moving_window(
687 self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
690 # We have an attention score for each of the RxL values
698 # softmax can operate only on one dimension, hence the
701 ar = ar.flatten(3).softmax(dim=3).view(ar.size())
703 ar = F.dropout(ar, self.attention_dropout, self.training)
705 # Compute the output for each head, flatten to concatenate
713 # Compute the final output
717 self.cache_Y[:, t0:t1] = Y @ self.w_O
719 return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
722 ##############################
725 class QKVAttention(nn.Module):
733 attention_dropout=0.0,
740 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
743 self.attention_dropout = attention_dropout
744 self.record_attention = False
746 self.w_q = randw(nb_heads, dim_qk, dim_model)
747 self.w_k = randw(nb_heads, dim_qk, dim_model)
748 self.w_v = randw(nb_heads, dim_v, dim_model)
749 self.w_o = randw(dim_v * nb_heads, dim_model)
751 def forward(self, bs):
755 self.causal or bs.complete()
756 ), "Partial evaluation is only possible for causal models"
759 self.cache_k = x_q.new_zeros(
760 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
762 self.cache_v = x_q.new_zeros(
763 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
765 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
767 q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
769 self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
770 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
772 self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
773 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
777 "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
778 ) / math.sqrt(self.w_q.size(1))
782 self.cache_attzero = (
783 torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
784 < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
788 :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
795 if self.record_attention:
798 a = F.dropout(a, self.attention_dropout, self.training)
801 "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
804 self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
806 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
809 ##############################
812 class MyGPT(nn.Module):
822 caterpillar_height=None,
826 attention_layer="caterpillar",
832 assert attention_layer in {
837 }, f"Unknown attention operator {attention_layer}."
839 if attention_layer == "caterpillar":
840 assert nb_lines % caterpillar_height == 0
841 self.caterpillar_length = nb_lines // caterpillar_height
842 self.caterpillar_height = caterpillar_height
844 self.caterpillar_length = -1
845 self.caterpillar_height = -1
847 assert dim_model % nb_heads == 0
849 self.embedding = nn.Sequential(
850 CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
851 AddPositionalEncoding(len_max),
857 if attention_layer == "mha":
861 dim_v=dim_model // nb_heads,
864 attention_dropout=dropout,
868 elif attention_layer == "dumbrec":
872 dim_v=dim_model // nb_heads,
875 attention_dropout=dropout,
879 elif attention_layer == "kvrec":
883 dim_v=dim_model // nb_heads,
886 attention_dropout=dropout,
890 elif attention_layer == "caterpillar":
894 dim_v=dim_model // nb_heads,
896 caterpillar_length=self.caterpillar_length,
897 caterpillar_height=self.caterpillar_height,
898 attention_dropout=dropout,
903 raise ValueError(f"Unknown attention type {attention_layer}.")
905 for b in range(nb_blocks):
908 CacheWrapper(nn.LayerNorm((dim_model,))),
913 nn.LayerNorm((dim_model,)),
914 nn.Linear(in_features=dim_model, out_features=dim_hidden),
916 nn.Linear(in_features=dim_hidden, out_features=dim_model),
922 self.trunk = nn.Sequential(*trunk_blocks)
924 self.readout = CacheWrapper(
925 nn.Linear(in_features=dim_model, out_features=vocabulary_size)
928 with torch.no_grad():
929 for m in self.modules():
930 if isinstance(m, nn.Embedding):
931 m.weight.normal_(mean=0, std=2e-2)
932 elif isinstance(m, nn.LayerNorm):
936 self.reset_inner_loss()
938 def forward(self, bs):
939 bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
941 # To make the code simpler in the Caterpillar layer, we pad
942 # here. It's unclear if/how much it hurts computationaly by
943 # increasing the sequence length for the other layers
945 if self.caterpillar_length > 0:
947 if bs.nb % self.caterpillar_length > 0:
948 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
950 bs = BracketedSequence(
951 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
952 bs.first + self.caterpillar_length,
957 bs = self.embedding(bs)
959 bs = self.readout(bs)
961 if self.caterpillar_length > 0:
962 bs = BracketedSequence(
963 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
964 bs.first - self.caterpillar_length,
971 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
972 # 1s where tokens should be generated. The others are kept
975 def masked_inplace_autoregression(
979 forbidden_tokens=None,
980 deterministic_synthesis=False,
982 input = input_src.to(self.readout.f.weight.device)
983 ar_mask = ar_mask_src.to(self.readout.f.weight.device)
984 to_generate = (ar_mask.sum(0) > 0).nonzero()
985 if to_generate.min() > 0:
987 BracketedSequence(input, 0, to_generate.min(), True)
988 ) # Needed to initialize the model's cache
989 for s in range(to_generate.min(), to_generate.max() + 1):
990 output = self(BracketedSequence(input, s, 1, s == 0)).x
991 logits = output[:, s]
992 if forbidden_tokens is not None:
993 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
994 if deterministic_synthesis:
995 t_next = logits.argmax(1)
997 dist = torch.distributions.categorical.Categorical(logits=logits)
998 t_next = dist.sample()
999 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
1001 input_src.copy_(input)
def reset_inner_loss(self):
    """Forward ``reset_inner_loss()`` to every sub-module that defines it.

    ``self.modules()`` also yields ``self``; it is skipped so the call does
    not recurse into this method.
    """
    for module in self.modules():
        if module is self or not hasattr(module, "reset_inner_loss"):
            continue
        module.reset_inner_loss()
1008 def get_inner_loss(self):
1009 l = torch.tensor([0.0], device=self.readout.f.weight.device)
1010 for m in self.modules():
1011 if m is not self and hasattr(m, "get_inner_loss"):
1012 l += m.get_inner_loss()
def record_attention(self, v=True):
    """Set the ``record_attention`` flag to ``v`` on every QKVAttention layer.

    QKVAttention checks this flag in its forward pass.
    """
    attention_layers = [m for m in self.modules() if isinstance(m, QKVAttention)]
    for layer in attention_layers:
        layer.record_attention = v
1020 def retrieve_attention(self):
1022 for m in self.modules():
1023 if isinstance(m, QKVAttention):
1028 ######################################################################
1030 if __name__ == "__main__":
1034 import matplotlib.pyplot as plt
1035 import matplotlib.collections as mc
1037 args = argparse.Namespace(
1038 gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
1041 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1043 dim_model, dim_keys, nb_heads = 512, 64, 1
1046 caterpillar = Caterpillar(
1047 dim_model=dim_model,
1049 dim_v=dim_model // nb_heads,
1051 caterpillar_length=16,
1052 caterpillar_height=32,
1053 attention_dropout=dropout,
1058 dim_model=dim_model,
1060 dim_v=dim_model // nb_heads,
1063 attention_dropout=dropout,
1067 linear = CacheWrapper(nn.Linear(512, 512)).to(device)
1069 x = torch.randn(1, 256, dim_model)
1074 ######################################################################
1077 fig.set_figheight(6)
1080 ax = fig.add_subplot(1, 1, 1)
1082 # ax.set_xlim(-1.5, 1.5)
1083 # ax.set_ylim(-1.5, 1.5)
1085 # ax.spines.right.set_visible(False)
1086 # ax.spines.top.set_visible(False)
1089 # t = np.arange(dt, 20.0, dt)
1090 # ax.semilogx(t, np.exp(-t / 5.0))
1092 ax.set_yscale("log")
1094 ######################################################################
1096 for label, model, thickness in [
1097 ("nn.Linear", linear, 0.2),
1098 ("mygpy.QKVAttention", qkv, 1),
1099 ("mygpt.Caterpillar", caterpillar, 2),
1101 y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
1103 for n, p in [("input", x)] + list(model.named_parameters()):
1104 print(f"Processing {model}.{n}")
1106 for t in range(y.size(1)):
1108 for d in torch.randperm(y.size(2))[:8]:
1109 sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
1110 assert not sg.isinf().any()
1111 assert not sg.isnan().any()
1112 data.append([t, sg.sum().item()])
1114 data = torch.tensor(data)
1115 # cx, cy = data[:, 0], data[:, 1]
1116 cy = data[:, 1].sort().values
1117 cx = torch.linspace(0, 1, cy.size(0))
1119 cx, cy, label=label + "." + n, linewidth=thickness
1120 ) # , color='gray', label='Input')
1122 # ax.legend(frameon=False, loc="top right")
1124 # Put a legend to the right of the current axis
1125 box = ax.get_position()
1126 ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
1127 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1129 filename = "plot.pdf"
1130 print(f"saving {filename}")
1131 fig.savefig(filename, bbox_inches="tight")
1133 # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
1134 # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
1139 ######################################################################
1146 caterpillar_length=7,
1147 caterpillar_height=3,
1148 attention_dropout=0.0,
1151 m.reset_inner_loss()
1152 x = torch.randn(1, 21 + 2 * 7, 4)
1153 y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1154 y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1155 y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1156 y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1157 print((y1 - y2).abs().max())
1158 print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1161 vocabulary_size = 128
1162 x = torch.randint(vocabulary_size, (6, 1024))
1165 vocabulary_size=vocabulary_size,
1181 # import torchvision.models as models
1182 # from torch.profiler import profile, record_function, ProfilerActivity
1184 # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1185 # with record_function("model_inference"):
1189 start_time = time.perf_counter()
1191 model(BracketedSequence(x))
1192 duration = time.perf_counter() - start_time
1196 # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1197 # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1199 # print("##############################################################")
1200 # y2 = torch.randn_like(y1)
1201 # for s in range(x.size(1)):
1202 # z = model(BracketedSequence(x, s, 1))
1203 # y2[:, s : s + 1] = z.slice()
1205 # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1207 ######################################################################