3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
18 from torch.nn import functional as F
24 ######################################################################
26 # A BracketedSequence is a BxTx... tensor with a first and a nb time
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
39 # Modules able to process brackets may implement a cache that is
40 # reset when init_cache is True
43 class BracketedSequence:
44 def __init__(self, x, first=None, nb=None, init_cache=None):
46 assert (first is None and nb is None and init_cache is None) or (
47 first is not None and nb is not None and init_cache is not None
50 self.first = 0 if first is None else first
51 self.nb = x.size(1) if nb is None else nb
52 self.init_cache = True if init_cache is None else init_cache
55 return self.x[:, self.first : self.first + self.nb]
58 return self.first == 0 and self.nb == self.x.size(1)
61 ######################################################################
64 class CacheWrapper(nn.Module):
65 def __init__(self, *f):
67 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
69 def forward(self, bs):
71 y = self.f(bs.slice())
72 self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
73 self.cache_y[:, bs.first : bs.first + bs.nb] = y
75 assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
76 assert bs.first + bs.nb <= self.cache_y.size(1)
77 self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
79 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82 ##############################
85 class WithResidual(nn.Module):
86 def __init__(self, *f):
88 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
90 def forward(self, bs):
91 return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94 ##############################
97 class AddPositionalEncoding(nn.Module):
98 def __init__(self, len_max):
100 self.len_max = len_max
102 # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
104 def forward(self, bs):
106 t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109 j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
114 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
116 self.cache_y = bs.x.new(bs.x.size())
118 self.cache_y[:, bs.first : bs.first + bs.nb] = (
119 bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
128 # X is /.../xTxD A is /.../xT Y_init is /.../xD
131 def pscan_dim(A, X, Y_init, dim=-2):
133 a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
135 A = A.reshape(a, T, *s[dim + 1 : -1])
136 X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139 Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
141 Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
143 Y = pscan.pscan(A, X, Y_init).reshape(s)
148 def pscan_shape(A, X, Y_init):
150 A = A.reshape(-1, s[-2])
151 X = X.reshape(-1, s[-2], s[-1])
154 Y_init = X.new_zeros(X.size(0), s[-1])
156 Y_init = Y_init.reshape(-1, s[-1])
158 Y = pscan.pscan(A, X, Y_init).reshape(s)
163 def nsum_shape(X, Y_init):
165 X = X.reshape(-1, s[-2], s[-1]) # ntd
167 Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170 for k in range(X.size(1)):
172 Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175 return torch.cat(result, dim=1).reshape(s)
178 ##############################
181 class DumbRec(nn.Module):
189 attention_dropout=0.0,
195 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
197 self.nb_lines = nb_lines
198 self.attention_dropout = attention_dropout
200 self.k_star = randw(nb_lines, dim_qk)
202 self.w_qw = randw(nb_heads, dim_qk, dim_model)
203 self.w_qr = randw(nb_heads, dim_qk, dim_model)
204 # self.w_k = randw(nb_heads, dim_qk, dim_model)
205 self.w_v = randw(nb_heads, dim_v, dim_model)
206 self.w_o = randw(dim_v * nb_heads, dim_model)
208 def reset_inner_loss(self):
209 self.acc_attention = 0
def get_inner_loss(self):
    """Return the l2 regularization term computed from the accumulated attention.

    The accumulated attention `self.acc_attention` is divided by the
    accumulated count `self.acc_nb`, and the sum of the squares of the
    resulting values is returned.

    A RuntimeWarning is emitted as a reminder that this regularization
    is active.
    """
    warnings.warn("l2 regularization", RuntimeWarning)
    mean_attention = self.acc_attention / self.acc_nb
    squared = mean_attention.pow(2)
    return squared.sum()
217 def forward(self, bs):
218 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
221 self.rec_v = x_q.new_zeros(
222 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
224 # self.rec_k = x_q.new_zeros(
225 # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
227 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
229 ######################################################################
232 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
234 warnings.warn("rotating key barrel", RuntimeWarning)
235 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
236 t_barrel = torch.arange(t0, t1, device=k_star.device)
237 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
239 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
241 k_star = k_star[l_barrel, t_barrel]
243 ######################################################################
244 # Compute the recurrent state
246 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
248 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
249 # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
255 ) / math.sqrt(self.w_qw.size(1))
257 aw = aw.softmax(dim=2) # nhlt
260 self.acc_attention += aw.sum(dim=(0, 1, 3))
261 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
263 aw = F.dropout(aw, self.attention_dropout, self.training)
265 A = 1 - aw.sum(dim=1) # nlt
267 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
268 # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
274 V0 = self.rec_v[:, :, t0 - 1]
275 # K0 = self.rec_k[:, :, t0 - 1]
277 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
278 # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
280 ######################################################################
281 # compute the readout
283 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
288 # self.rec_k[:, :, t0:t1],
290 ) / math.sqrt(self.w_qr.size(1))
292 ar = ar.softmax(dim=2) # nhlt
294 ar = F.dropout(ar, self.attention_dropout, self.training)
299 self.rec_v[:, :, t0:t1],
302 self.cache_y[:, t0:t1] = y @ self.w_o
304 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307 ##############################
310 class KVRec(nn.Module):
318 attention_dropout=0.0,
324 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
326 self.nb_lines = nb_lines
327 self.attention_dropout = attention_dropout
329 self.k_star = randw(nb_lines, dim_qk)
331 self.w_qw = randw(nb_heads, dim_qk, dim_model)
332 self.w_qr = randw(nb_heads, dim_qk, dim_model)
333 self.w_k = randw(nb_heads, dim_qk, dim_model)
334 self.w_v = randw(nb_heads, dim_v, dim_model)
335 self.w_o = randw(dim_v * nb_heads, dim_model)
337 def reset_inner_loss(self):
338 self.acc_attention = 0
def get_inner_loss(self):
    """Return the l2 regularization term computed from the accumulated attention.

    The accumulated attention `self.acc_attention` is divided by the
    accumulated count `self.acc_nb`, and the sum of the squares of the
    resulting values is returned.

    A RuntimeWarning is emitted as a reminder that this regularization
    is active.
    """
    warnings.warn("l2 regularization", RuntimeWarning)
    mean_attention = self.acc_attention / self.acc_nb
    squared = mean_attention.pow(2)
    return squared.sum()
    # Disabled alternatives kept for reference:
    #  - no inner loss at all:
    #      torch.tensor([0], device=self.w_qw.device)
    #  - "side regularization":
    #      (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
351 def forward(self, bs):
352 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
355 self.rec_v = x_q.new_zeros(
356 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
358 self.rec_k = x_q.new_zeros(
359 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
361 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
363 ######################################################################
366 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
368 warnings.warn("rotating key barrel", RuntimeWarning)
369 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
370 t_barrel = torch.arange(t0, t1, device=k_star.device)
371 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
373 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
375 k_star = k_star[l_barrel, t_barrel]
377 ######################################################################
378 # Compute the recurrent state
380 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
382 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
383 k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
389 ) / math.sqrt(self.w_qw.size(1))
391 aw = aw.softmax(dim=2) # nhlt
394 # We want all the memory lines to be used similarly
395 self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum accross NxHx_xT
396 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
398 aw = F.dropout(aw, self.attention_dropout, self.training)
400 A = 1 - aw.sum(dim=1) # nlt
402 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
403 K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
409 V0 = self.rec_v[:, :, t0 - 1]
410 K0 = self.rec_k[:, :, t0 - 1]
412 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
413 self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
415 ######################################################################
416 # compute the readout
418 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
423 self.rec_k[:, :, t0:t1],
424 ) / math.sqrt(self.w_qr.size(1))
426 ar = ar.softmax(dim=2) # nhlt
428 ar = F.dropout(ar, self.attention_dropout, self.training)
433 self.rec_v[:, :, t0:t1],
436 self.cache_y[:, t0:t1] = y @ self.w_o
438 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441 ##############################
444 # Returns a tensor with an additional index at rank win_dim, that moves
445 # along the same dimension as dim, on a domain {0...win_size-1}, and
446 # dim is restricted to a domain reduced by win_size-1 values.
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of `x` exposing sliding windows along `dim`.

    The result has one more dimension than `x`: a new axis of length
    `win_size` is inserted at position `win_dim`, and the length of the
    axis `dim` is reduced to `x.size(dim) - win_size + 1`. Moving by one
    step along the new axis moves by one step along `dim` in `x`, so the
    entry at window start `t` and window offset `k` is the entry of `x`
    at index `t + k` along `dim`.

    The result is built with `as_strided`, hence it is a view on the
    storage of `x` and no data is copied.

    Args:
        x: the tensor to window.
        dim: the axis of `x` along which the window slides. Negative
            values are counted from the last axis of `x`.
        win_dim: the position at which the window axis is inserted.
        win_size: the length of the window.

    Returns:
        The windowed view of `x`.
    """
    # A negative `dim` must be normalized first: with dim == -1 the slice
    # size[dim + 1 :] below would be size[0:] and the size and stride
    # tuples would end up with different lengths.
    if dim < 0:
        dim += x.dim()

    src_size, src_stride = tuple(x.size()), tuple(x.stride())

    # Shrink the windowed axis so that every window fits inside x.
    shrunk = src_size[:dim] + (src_size[dim] - win_size + 1,) + src_size[dim + 1 :]

    # Insert the window axis; it walks through the memory with the same
    # stride as the axis it slides along.
    out_size = shrunk[:win_dim] + (win_size,) + shrunk[win_dim:]
    out_stride = src_stride[:win_dim] + (src_stride[dim],) + src_stride[win_dim:]

    return x.as_strided(size=out_size, stride=out_stride)
458 ##############################
461 class Caterpillar(nn.Module):
470 attention_dropout=0.0,
475 warnings.warn("Caterpillar", RuntimeWarning)
478 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
480 self.caterpillar_length = caterpillar_length
481 self.caterpillar_height = caterpillar_height
482 self.attention_dropout = attention_dropout
484 warnings.warn("flash back", RuntimeWarning)
485 self.proba_flashback = 1e-2
487 self.w_G = randw(nb_heads, caterpillar_height, dim_model)
488 self.b_G = nn.Parameter(
490 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
494 self.w_K = randw(nb_heads, dim_qk, dim_model)
495 self.w_V = randw(nb_heads, dim_v, dim_model)
496 self.w_Q = randw(nb_heads, dim_qk, dim_model)
497 self.w_O = randw(dim_v * nb_heads, dim_model)
499 self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
500 self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
502 def reset_inner_loss(self):
503 self.acc_attention = 0
def get_inner_loss(self):
    """Return a zero inner loss, on the device of the layer's parameters.

    The l2 regularization on the accumulated attention,
    (self.acc_attention / self.acc_nb).pow(2).sum(), is disabled here,
    together with its "l2 regularization" RuntimeWarning.
    """
    device = self.w_Q.device
    return torch.tensor([0], device=device)
511 def forward(self, bs):
512 # Dimensions to make the source a bit clearer, that's needed
514 X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
519 DV = self.w_V.size(1)
520 DK = self.w_K.size(1)
521 DM = self.w_O.size(1)
522 CH = self.caterpillar_height
523 CL = self.caterpillar_length
526 t0 >= CL and (t1 - t0) % CL == 0
527 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
529 # We cache values to deal efficiently with auto-regression
532 self.rec_V = X.new_zeros(N, CH, T, DV)
533 self.rec_K = X.new_zeros(N, CH, T, DK)
534 # We start the recurrent sequences with optimizable
535 # initial values. No idea if it helps.
536 self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
537 self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
539 self.cache_Y = X.new_zeros(N, T, DM)
541 ######################################################################
542 # Compute the recurrent state
544 # This is the Gating sequence that modulates the storing of
545 # the new key and value in the CH pairs of the current
546 # stack. The CH gating values are independent, which means
547 # that the current K/V could be stored in multiple pairs of the
548 # recurrent state, or not at all.
551 torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
554 # That was a bad idea
555 # G = F.dropout(G, self.attention_dropout, self.training)
557 V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
558 K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
560 # We prepare the arguments for the parallel scan
563 gated_V = torch.einsum("nhet,nhtd->netd", G, V)
564 gated_K = torch.einsum("nhet,nhtd->netd", G, K)
566 init_rec_V = self.rec_V[:, :, t0 - CL : t0]
567 init_rec_K = self.rec_K[:, :, t0 - CL : t0]
569 # Here there is a trick: Since the stack at time t is computed
570 # by updating that at time t-L, the parallel scan operates
571 # with a period of L. To do so we split the time indexing in
572 # two axes, the second of size CL, and run the parallel scan
573 # using the other as the sequence index.
575 A = A.unflatten(2, (-1, CL))
576 gated_V = gated_V.unflatten(2, (-1, CL))
577 gated_K = gated_K.unflatten(2, (-1, CL))
579 next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
580 next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
582 # Put back the sequence index
584 self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
585 self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
587 if self.training and self.proba_flashback > 0.0:
588 # This piece of code makes the assumption that there is
589 # nothing informative before t0, otherwise we'd have to
590 # implement a cache for V and K too. This should not be
591 # too much of a problem since this is used only during
592 # training, where full sequences are available
594 n = torch.arange(N, device=X.device)[:, None, None, None]
595 t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
596 dv = torch.arange(DV, device=X.device)[None, None, None, :]
597 dk = torch.arange(DK, device=X.device)[None, None, None, :]
600 torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
603 src_time = t - u - t0
604 src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
607 torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
610 self.rec_V[:, :, t0:t1] = (
611 mask * V[n, src_head, src_time, dv]
612 + (1 - mask) * self.rec_V[:, :, t0:t1]
615 self.rec_K[:, :, t0:t1] = (
616 mask * K[n, src_head, src_time, dk]
617 + (1 - mask) * self.rec_K[:, :, t0:t1]
620 ######################################################################
621 # compute the readout
623 Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
625 # We build tensors NxHxTxFxL where N is the sample index, H
626 # the head, T the time, F the row in the caterpillar, and L
627 # the column in the caterpillar
629 windowed_V = moving_window(
630 self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
633 windowed_K = moving_window(
634 self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
637 # We have an attention score for each of the CHxCL values
645 # softmax can operate only on one dimension, hence the
648 ar = ar.flatten(3).softmax(dim=3).view(ar.size())
650 ar = F.dropout(ar, self.attention_dropout, self.training)
652 # Compute the output for each head, flatten to concatenate
660 # Compute the final output
662 self.cache_Y[:, t0:t1] = Y @ self.w_O
664 return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
667 ##############################
670 class QKVAttention(nn.Module):
678 attention_dropout=0.0,
683 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
686 self.attention_dropout = attention_dropout
687 self.record_attention = False
689 self.w_q = randw(nb_heads, dim_qk, dim_model)
690 self.w_k = randw(nb_heads, dim_qk, dim_model)
691 self.w_v = randw(nb_heads, dim_v, dim_model)
692 self.w_o = randw(dim_v * nb_heads, dim_model)
694 def forward(self, bs):
698 self.causal or bs.complete()
699 ), "Partial evaluation is only possible for causal models"
702 self.cache_k = x_q.new_zeros(
703 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
705 self.cache_v = x_q.new_zeros(
706 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
708 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
710 q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
712 self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
713 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
715 self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
716 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
720 "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
721 ) / math.sqrt(self.w_q.size(1))
725 self.cache_attzero = (
726 torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
727 < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
731 :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
738 if self.record_attention:
741 a = F.dropout(a, self.attention_dropout, self.training)
744 "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
747 self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
749 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
752 ##############################
755 class MyGPT(nn.Module):
765 caterpillar_height=None,
770 attention_layer="kvrec",
774 assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
776 if attention_layer == "caterpillar":
777 assert nb_lines % caterpillar_height == 0
778 self.caterpillar_length = nb_lines // caterpillar_height
779 self.caterpillar_height = caterpillar_height
781 self.caterpillar_length = -1
782 self.caterpillar_height = -1
784 assert dim_model % nb_heads == 0
786 self.embedding = nn.Sequential(
787 CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
788 AddPositionalEncoding(len_max),
794 if attention_layer == "mha":
798 dim_v=dim_model // nb_heads,
801 attention_dropout=dropout,
803 elif attention_layer == "dumbrec":
810 attention_dropout=dropout,
812 elif attention_layer == "kvrec":
819 attention_dropout=dropout,
821 elif attention_layer == "caterpillar":
827 caterpillar_length=self.caterpillar_length,
828 caterpillar_height=self.caterpillar_height,
829 attention_dropout=dropout,
832 raise ValueError(f"Unknown attention type {attention_layer}.")
834 for b in range(nb_blocks):
837 CacheWrapper(nn.LayerNorm((dim_model,))),
842 nn.LayerNorm((dim_model,)),
843 nn.Linear(in_features=dim_model, out_features=dim_hidden),
845 nn.Linear(in_features=dim_hidden, out_features=dim_model),
851 self.trunk = nn.Sequential(*trunk_blocks)
853 self.readout = CacheWrapper(
854 nn.Linear(in_features=dim_model, out_features=vocabulary_size)
857 with torch.no_grad():
858 for m in self.modules():
859 if isinstance(m, nn.Embedding):
860 m.weight.normal_(mean=0, std=2e-2)
861 elif isinstance(m, nn.LayerNorm):
865 self.reset_inner_loss()
867 def forward(self, bs):
868 bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
870 # To make the code simpler in the Caterpillar layer, we pad
871 # here. It's unclear if/how much it hurts computationally by
872 # increasing the sequence length for the other layers
874 if self.caterpillar_length > 0:
876 if bs.nb % self.caterpillar_length > 0:
877 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
879 bs = BracketedSequence(
880 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
881 bs.first + self.caterpillar_length,
886 bs = self.embedding(bs)
888 bs = self.readout(bs)
890 if self.caterpillar_length > 0:
891 bs = BracketedSequence(
892 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
893 bs.first - self.caterpillar_length,
900 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
901 # 1s where tokens should be generated. The others are kept
904 def masked_inplace_autoregression(
908 forbidden_tokens=None,
909 deterministic_synthesis=False,
911 input = input_src.to(self.readout.f.weight.device)
912 ar_mask = ar_mask_src.to(self.readout.f.weight.device)
913 to_generate = (ar_mask.sum(0) > 0).nonzero()
914 if to_generate.min() > 0:
916 BracketedSequence(input, 0, to_generate.min(), True)
917 ) # Needed to initialize the model's cache
918 for s in range(to_generate.min(), to_generate.max() + 1):
919 output = self(BracketedSequence(input, s, 1, s == 0)).x
920 logits = output[:, s]
921 if forbidden_tokens is not None:
922 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
923 if deterministic_synthesis:
924 t_next = logits.argmax(1)
926 dist = torch.distributions.categorical.Categorical(logits=logits)
927 t_next = dist.sample()
928 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
930 input_src.copy_(input)
932 def reset_inner_loss(self):
933 for m in self.modules():
934 if m is not self and hasattr(m, "reset_inner_loss"):
937 def get_inner_loss(self):
938 l = torch.tensor([0.0], device=self.readout.f.weight.device)
939 for m in self.modules():
940 if m is not self and hasattr(m, "get_inner_loss"):
941 l += m.get_inner_loss()
def record_attention(self, v=True):
    """Set the `record_attention` flag of every QKVAttention sub-module to `v`."""
    attention_layers = (
        module for module in self.modules() if isinstance(module, QKVAttention)
    )
    for layer in attention_layers:
        layer.record_attention = v
949 def retrieve_attention(self):
951 for m in self.modules():
952 if isinstance(m, QKVAttention):
957 ######################################################################
959 if __name__ == "__main__":
960 print("Basic check.")
967 caterpillar_length=7,
968 caterpillar_height=3,
969 attention_dropout=0.0,
973 x = torch.randn(1, 21 + 2 * 7, 4)
974 y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
975 y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
976 y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
977 y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
978 print((y1 - y2).abs().max())
979 print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
982 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
984 vocabulary_size = 128
985 x = torch.randint(vocabulary_size, (6, 1024))
988 vocabulary_size=vocabulary_size,
1004 # import torchvision.models as models
1005 # from torch.profiler import profile, record_function, ProfilerActivity
1007 # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1008 # with record_function("model_inference"):
1012 start_time = time.perf_counter()
1014 model(BracketedSequence(x))
1015 duration = time.perf_counter() - start_time
1019 # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1020 # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1022 # print("##############################################################")
1023 # y2 = torch.randn_like(y1)
1024 # for s in range(x.size(1)):
1025 # z = model(BracketedSequence(x, s, 1))
1026 # y2[:, s : s + 1] = z.slice()
1028 # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1030 ######################################################################