3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
13 # This implementation is equipped with RNN layers to replace the MHA
20 from torch.nn import functional as F
26 ######################################################################
28 # A BracketedSequence is a BxTx... tensor with a first and a nb time
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True
class BracketedSequence:
    """A tensor together with the time bracket that has to be processed.

    ``x`` is indexed as ``x[:, t, ...]``; the bracket is the range
    ``[first, first + nb)`` along dimension 1.

    Args:
        x: the full tensor (only ``x.size(1)`` and slicing are used here).
        first: index of the first time step of the bracket.
        nb: number of time steps in the bracket.
        init_cache: whether modules processing the bracket should
            (re)initialize their cache.

    ``first``, ``nb`` and ``init_cache`` must be given all together or not
    at all. When omitted, the bracket covers the whole time axis and
    ``init_cache`` is True.

    Raises:
        AssertionError: if only some of ``first``, ``nb``, ``init_cache``
            are provided.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        ), "first, nb and init_cache must be given all together or not at all"

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        """Return the part of ``x`` covered by the bracket."""
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        """Return True when the bracket covers the whole time axis."""
        return self.first == 0 and self.nb == self.x.size(1)
63 ######################################################################
class CacheWrapper(nn.Module):
    """Apply a module to the bracket only and cache the result over time.

    The wrapped module (or the ``nn.Sequential`` of the given modules) is
    stored in ``self.f``. It is applied to ``bs.slice()`` and its output is
    written into ``self.cache_y`` at the bracket's time range.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        """Process the bracket of ``bs`` and return the cache as a new
        BracketedSequence with the same bracket."""
        if bs.init_cache:
            y = self.f(bs.slice())
            # Allocate the cache for the full time axis. It is left
            # uninitialized: only time steps that have been processed
            # hold meaningful values.
            self.cache_y = y.new_empty((y.size(0), bs.x.size(1), *y.shape[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
84 ##############################
class WithResidual(nn.Module):
    """Residual connection around a module that maps a BracketedSequence
    to a BracketedSequence.

    The wrapped module (or the ``nn.Sequential`` of the given modules) is
    stored in ``self.f``.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        """Return ``bs.x + f(bs).x`` with the bracket of ``bs`` unchanged."""
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
96 ##############################
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to a BracketedSequence.

    [Vaswani et al 2017]
    PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    where L is ``len_max`` and D is the size of the last dimension.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def _encoding(self, x):
        """Return the (T, D) encoding table for ``x`` of size (N, T, D)."""
        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
        # k is 0 on even channels and 1 on odd ones; adding pi/2 turns the
        # sine into a cosine on the odd channels.
        k = j % 2
        return torch.sin(
            t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k
        )

    def forward(self, bs):
        """Add the encoding on the bracket of ``bs`` and return the cached
        result as a BracketedSequence with the same bracket."""
        if bs.init_cache:
            self.pe = self._encoding(bs.x)
            # Uninitialized on purpose: only processed time steps are valid.
            self.cache_y = bs.x.new_empty(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
130 # X is /.../xTxD A is /.../xT Y_init is /.../xD
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` with dimension ``dim`` of ``X`` as the sequence axis.

    Dimensions before ``dim`` are merged into a single leading dimension;
    dimensions between ``dim`` and the last one are kept as they are.

    Args:
        A: coefficients, reshaped to (a, T, *inner).
        X: values, reshaped to (a, T, *inner, D).
        Y_init: initial state reshaped to (a, *inner, D), or None to start
            from zeros.
        dim: index of the sequence dimension in ``X``.

    Returns:
        The result of ``pscan.pscan`` reshaped to the original size of ``X``.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    inner = s[dim + 1 : -1]

    A = A.reshape(a, T, *inner)
    X = X.reshape(a, T, *inner, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *inner, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *inner, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` over the second-to-last dimension of ``X``.

    All leading dimensions are merged into one before the call and restored
    afterwards.

    Args:
        A: coefficients, reshaped to (-1, T).
        X: values, reshaped to (-1, T, D).
        Y_init: initial state reshaped to (-1, D), or None to start from
            zeros.

    Returns:
        The result of ``pscan.pscan`` reshaped to the original size of ``X``.
    """
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    return pscan.pscan(A, X, Y_init).reshape(s)
def nsum_shape(X, Y_init):
    """Cumulative sum along the second-to-last dimension with norm clipping.

    At every step the running sum is rescaled so that its norm along the
    last dimension does not exceed 1.

    Args:
        X: tensor of size (..., T, D).
        Y_init: initial state of size (..., D), or None to start from 0.

    Returns:
        A tensor of the same size as ``X`` holding the successive states.
    """
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        # Only shrink: states with norm below 1 are left untouched.
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    if not result:
        # Empty time axis: there is nothing to accumulate.
        return X.reshape(s)

    return torch.stack(result, dim=1).reshape(s)
180 ##############################
# DumbRec: nn.Module processing a BracketedSequence with a recurrent value
# memory `rec_v` (nb_lines lines per sample) that is updated with
# pscan_shape over the bracket [t0, t1) and read out into `cache_y`.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers): the __init__ header, most of its parameter list and several
# statement openers/closers are not visible. The comments below only
# describe what the visible lines establish.
183 class DumbRec(nn.Module):
191 attention_dropout=0.0,
# Parameter initializer: standard normal values scaled by 1/sqrt(last dim).
197 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
199 self.nb_lines = nb_lines
200 self.attention_dropout = attention_dropout
# One learned key per memory line.
202 self.k_star = randw(nb_lines, dim_qk)
# Per-head projections: write queries, read queries, values, and the
# output projection applied to the concatenated heads.
204 self.w_qw = randw(nb_heads, dim_qk, dim_model)
205 self.w_qr = randw(nb_heads, dim_qk, dim_model)
206 # self.w_k = randw(nb_heads, dim_qk, dim_model)
207 self.w_v = randw(nb_heads, dim_v, dim_model)
208 self.w_o = randw(dim_v * nb_heads, dim_model)
# Resets the statistics used by get_inner_loss.
# NOTE(review): acc_nb is presumably reset on a line not shown here.
210 def reset_inner_loss(self):
211 self.acc_attention = 0
# Emits a RuntimeWarning and returns the sum of squares of
# acc_attention / acc_nb.
214 def get_inner_loss(self):
215 warnings.warn("l2 regularization", RuntimeWarning)
216 return (self.acc_attention / self.acc_nb).pow(2).sum()
217 # return torch.tensor([0], device=self.w_qw.device)
# Processes the bracket [t0, t1) of bs and returns the cached output as a
# BracketedSequence with the same bracket.
219 def forward(self, bs):
220 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
# Zero-initialized buffers covering the full time axis of x_q.
# NOTE(review): presumably allocated only when bs.init_cache is true; the
# guarding line is not visible.
223 self.rec_v = x_q.new_zeros(
224 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
226 # self.rec_k = x_q.new_zeros(
227 # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
229 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
231 ######################################################################
# NOTE(review): k_star is assigned again two visible lines below; unless a
# hidden line makes one of them conditional, this first value is unused.
234 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
236 warnings.warn("rotating key barrel", RuntimeWarning)
# The keys are expanded along time, then indexed with (l_barrel, t_barrel)
# built from the line index plus the time index of the bracket.
237 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
238 t_barrel = torch.arange(t0, t1, device=k_star.device)
239 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
241 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
243 k_star = k_star[l_barrel, t_barrel]
245 ######################################################################
246 # Compute the recurrent state
# Per-head projections of the bracket ("ntc,hdc->nhtd").
248 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
250 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
251 # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
# Scores are scaled by 1/sqrt(dim_qk) and normalized with a softmax over
# dimension 2 (labelled `l` in the trailing comment).
257 ) / math.sqrt(self.w_qw.size(1))
259 aw = aw.softmax(dim=2) # nhlt
# Statistics for the inner loss: weights summed over every axis but 2,
# together with the number of summed terms.
262 self.acc_attention += aw.sum(dim=(0, 1, 3))
263 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
265 aw = F.dropout(aw, self.attention_dropout, self.training)
# A is one minus the weights summed over dimension 1 (heads); V mixes the
# head values with the same weights. Both are passed to pscan_shape.
267 A = 1 - aw.sum(dim=1) # nlt
269 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
270 # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
# Initial state: the recurrent memory at the step before the bracket.
# NOTE(review): the handling of t0 == 0 is not visible here.
276 V0 = self.rec_v[:, :, t0 - 1]
277 # K0 = self.rec_k[:, :, t0 - 1]
279 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
280 # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
282 ######################################################################
283 # compute the readout
285 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
290 # self.rec_k[:, :, t0:t1],
292 ) / math.sqrt(self.w_qr.size(1))
294 ar = ar.softmax(dim=2) # nhlt
296 ar = F.dropout(ar, self.attention_dropout, self.training)
301 self.rec_v[:, :, t0:t1],
# Output projection, written at the bracket's position in the cache.
304 self.cache_y[:, t0:t1] = y @ self.w_o
306 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
309 ##############################
# KVRec: nn.Module processing a BracketedSequence with two recurrent
# memories, `rec_k` and `rec_v` (nb_lines lines per sample), both updated
# with pscan_shape over the bracket [t0, t1) and read out into `cache_y`.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers): the __init__ header, most of its parameter list and several
# statement openers/closers are not visible. The comments below only
# describe what the visible lines establish.
312 class KVRec(nn.Module):
320 attention_dropout=0.0,
# Parameter initializer: standard normal values scaled by 1/sqrt(last dim).
326 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
328 self.nb_lines = nb_lines
329 self.attention_dropout = attention_dropout
# One learned key per memory line.
331 self.k_star = randw(nb_lines, dim_qk)
# Per-head projections: write queries, read queries, keys, values, and the
# output projection applied to the concatenated heads.
333 self.w_qw = randw(nb_heads, dim_qk, dim_model)
334 self.w_qr = randw(nb_heads, dim_qk, dim_model)
335 self.w_k = randw(nb_heads, dim_qk, dim_model)
336 self.w_v = randw(nb_heads, dim_v, dim_model)
337 self.w_o = randw(dim_v * nb_heads, dim_model)
# Resets the statistics used by get_inner_loss.
# NOTE(review): acc_nb is presumably reset on a line not shown here.
339 def reset_inner_loss(self):
340 self.acc_attention = 0
# Emits a RuntimeWarning and returns the sum of squares of
# acc_attention / acc_nb.
343 def get_inner_loss(self):
344 warnings.warn("l2 regularization", RuntimeWarning)
345 return (self.acc_attention / self.acc_nb).pow(2).sum()
346 # return torch.tensor([0], device=self.w_qw.device)
347 # warnings.warn("side regularization", RuntimeWarning)
349 # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
351 # return torch.tensor([0], device=self.w_qw.device)
# Processes the bracket [t0, t1) of bs and returns the cached output as a
# BracketedSequence with the same bracket.
353 def forward(self, bs):
354 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
# Zero-initialized buffers covering the full time axis of x_q.
# NOTE(review): presumably allocated only when bs.init_cache is true; the
# guarding line is not visible.
357 self.rec_v = x_q.new_zeros(
358 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
360 self.rec_k = x_q.new_zeros(
361 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
363 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
365 ######################################################################
# NOTE(review): k_star is assigned again two visible lines below; unless a
# hidden line makes one of them conditional, this first value is unused.
368 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
370 warnings.warn("rotating key barrel", RuntimeWarning)
# The keys are expanded along time, then indexed with (l_barrel, t_barrel)
# built from the line index plus the time index of the bracket.
371 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
372 t_barrel = torch.arange(t0, t1, device=k_star.device)
373 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
375 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
377 k_star = k_star[l_barrel, t_barrel]
379 ######################################################################
380 # Compute the recurrent state
# Per-head projections of the bracket ("ntc,hdc->nhtd").
382 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
384 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
385 k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
# Scores are scaled by 1/sqrt(dim_qk) and normalized with a softmax over
# dimension 2 (labelled `l` in the trailing comment).
391 ) / math.sqrt(self.w_qw.size(1))
393 aw = aw.softmax(dim=2) # nhlt
396 # We want all the memory lines to be used similarly
397 self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum across NxHx_xT
398 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
400 aw = F.dropout(aw, self.attention_dropout, self.training)
# A is one minus the weights summed over dimension 1 (heads); V and K mix
# the head values and keys with the same weights.
402 A = 1 - aw.sum(dim=1) # nlt
404 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
405 K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
# Initial states: the recurrent memories at the step before the bracket.
# NOTE(review): the handling of t0 == 0 is not visible here.
411 V0 = self.rec_v[:, :, t0 - 1]
412 K0 = self.rec_k[:, :, t0 - 1]
414 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
415 self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
417 ######################################################################
418 # compute the readout
420 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
# The read scores use the recurrent keys of the bracket.
425 self.rec_k[:, :, t0:t1],
426 ) / math.sqrt(self.w_qr.size(1))
428 ar = ar.softmax(dim=2) # nhlt
430 ar = F.dropout(ar, self.attention_dropout, self.training)
435 self.rec_v[:, :, t0:t1],
# Output projection, written at the bracket's position in the cache.
438 self.cache_y[:, t0:t1] = y @ self.w_o
440 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
443 ##############################
446 # Returns a tensor with an additional index at rank win_dim, that moves
447 # along the same dimension as dim, on a domain {0...win_size-1}, and
448 # dim is restricted to a domain reduced by win_size-1 values.
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of ``x`` with an extra sliding-window dimension.

    The result has one more dimension than ``x``, inserted at rank
    ``win_dim`` and of size ``win_size``; it moves along the same memory
    direction as ``dim``. Dimension ``dim`` is shortened by
    ``win_size - 1``. For ``win_dim > dim`` the element at index ``i`` on
    ``dim`` and ``j`` on the window dimension is ``x[..., i + j, ...]``.

    No data is copied: the result is built with ``as_strided`` and
    overlapping windows share memory.

    Args:
        x: input tensor.
        dim: dimension of ``x`` the window slides along (may be negative).
        win_dim: rank at which the window dimension is inserted.
        win_size: number of elements per window.
    """
    if dim < 0:
        # Normalize so that the tuple slicing below is valid for dim == -1.
        dim += x.dim()

    size, stride = tuple(x.size()), tuple(x.stride())
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
460 ##############################
# Caterpillar: nn.Module processing a BracketedSequence with recurrent key
# and value buffers `rec_K` / `rec_V` that have caterpillar_height rows. The
# state at time t is computed from the state at t - caterpillar_length (see
# the comment before the parallel scan), and the readout attends over a
# moving window of caterpillar_length columns.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers): the __init__ header and most of its parameter list, the
# definitions of N, T, G and A, the body of the gate-dropout branch and
# several statement openers/closers are not visible. The comments below
# only describe what the visible lines establish.
463 class Caterpillar(nn.Module):
472 attention_dropout=0.0,
477 warnings.warn("Caterpillar", RuntimeWarning)
# Parameter initializer: standard normal values scaled by 1/sqrt(last dim).
480 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
482 self.caterpillar_length = caterpillar_length
483 self.caterpillar_height = caterpillar_height
484 self.attention_dropout = attention_dropout
# Gate dropout is disabled by default; forward only takes the gate-dropout
# branch when this is > 0 and the module is in training mode.
486 self.proba_gate_dropout = 0.0
# Gating weights and bias, one gate per (head, row).
# NOTE(review): -log(caterpillar_height - 1) fails for caterpillar_height
# == 1; how the two visible arguments are used is on hidden lines.
488 self.w_G = randw(nb_heads, caterpillar_height, dim_model)
489 self.b_G = nn.Parameter(
491 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
# Per-head key / value / query projections and the output projection.
495 self.w_K = randw(nb_heads, dim_qk, dim_model)
496 self.w_V = randw(nb_heads, dim_v, dim_model)
497 self.w_Q = randw(nb_heads, dim_qk, dim_model)
498 self.w_O = randw(dim_v * nb_heads, dim_model)
# Learned initial content of the recurrent buffers.
500 self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
501 self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
503 def reset_inner_loss(self):
504 self.acc_attention = 0
# Returns a constant zero tensor on the device of w_Q.
507 def get_inner_loss(self):
508 # warnings.warn("l2 regularization", RuntimeWarning)
509 # return (self.acc_attention / self.acc_nb).pow(2).sum()
510 return torch.tensor([0], device=self.w_Q.device)
# Processes the bracket [t0, t1) of bs and returns the cached output as a
# BracketedSequence with the same bracket.
512 def forward(self, bs):
513 # Dimensions to make the source a bit clearer, that's needed
515 X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
520 DV = self.w_V.size(1)
521 DK = self.w_K.size(1)
522 DM = self.w_O.size(1)
523 CH = self.caterpillar_height
524 CL = self.caterpillar_length
# The bracket must start at or after CL and have a length multiple of CL.
# NOTE(review): the message says "greater than" while the check is ">=".
527 t0 >= CL and (t1 - t0) % CL == 0
528 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
530 # We cache values to deal efficiently with auto-regression
# NOTE(review): N and T are defined on hidden lines, and these allocations
# are presumably guarded by bs.init_cache — confirm.
533 self.rec_V = X.new_zeros(N, CH, T, DV)
534 self.rec_K = X.new_zeros(N, CH, T, DK)
535 # We start the recurrent sequences with optimizable
536 # initial values. No idea if it helps.
537 self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
538 self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
540 self.cache_Y = X.new_zeros(N, T, DM)
# Per-head value and key projections of the bracket.
542 V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
543 K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
545 ######################################################################
546 # Compute the recurrent state
548 # This is the Gating sequence that modulates the storing of
549 # the new key and value in the CH pairs of the current
550 # stack. There are CH independent gating values, which means
551 # that the current K/V may be stored in multiple pairs of the
552 # recurrent state, or not at all.
555 torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
558 # Clip the gating to avoid values greater than 1 when several
559 # heads hit the same row
# Divides by the sum over heads (dim 1) only when that sum exceeds 1.
561 G = G / G.sum(1, keepdim=True).clamp(min=1)
563 # We prepare the arguments for the parallel scan
566 gated_V = torch.einsum("nhet,nhtd->netd", G, V)
567 gated_K = torch.einsum("nhet,nhtd->netd", G, K)
569 # We start from cached values, which matters in inference
571 init_rec_V = self.rec_V[:, :, t0 - CL : t0]
572 init_rec_K = self.rec_K[:, :, t0 - CL : t0]
574 ######################################################################
576 if self.training and self.proba_gate_dropout > 0.0:
577 # This is a better implementation of "flashbacks". A is
578 # NxExT where e is the caterpillar's row.
580 warnings.warn("gate dropout", RuntimeWarning)
583 #################################################################
586 # Here there is a trick: Since the stack at position t is
587 # computed by updating that at position t-CL, the parallel
588 # scan operates with a period of CL. To do so we split the
589 # sequence indexing in two axes, the second of size CL, and
590 # run the parallel scan using the first as the sequence index.
592 A = A.unflatten(2, (-1, CL))
593 gated_V = gated_V.unflatten(2, (-1, CL))
594 gated_K = gated_K.unflatten(2, (-1, CL))
596 next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
597 next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
# Merge the two time axes back and store the new states in the buffers.
599 self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
600 self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
602 ######################################################################
603 # compute the readout
605 Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
607 # We build tensors NxHxTxFxL where N is the sample index, H
608 # the head, T the time, F the row in the caterpillar, and L
609 # the column in the caterpillar
# Windows of CL consecutive time steps over the recurrent buffers,
# starting CL - 1 steps before the bracket.
611 windowed_V = moving_window(
612 self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
615 windowed_K = moving_window(
616 self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
619 # We have an attention score for each of the CHxCL values
627 # softmax can operate only on one dimension, hence the
# Flatten everything from dimension 3 on, apply the softmax, restore the
# original size.
630 ar = ar.flatten(3).softmax(dim=3).view(ar.size())
632 ar = F.dropout(ar, self.attention_dropout, self.training)
634 # Compute the output for each head, flatten to concatenate
642 # Compute the final output
644 self.cache_Y[:, t0:t1] = Y @ self.w_O
646 return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
649 ##############################
# QKVAttention: multi-head attention on a BracketedSequence, with caches for
# keys (`cache_k`), values (`cache_v`) and outputs (`cache_y`) so that a
# bracket can be processed given the previously processed time steps.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers): the __init__ header and most of its parameter list, the
# assignment of self.causal and x_q, the application of the mask and several
# statement openers/closers are not visible. The comments below only
# describe what the visible lines establish.
652 class QKVAttention(nn.Module):
660 attention_dropout=0.0,
# Parameter initializer: standard normal values scaled by 1/sqrt(last dim).
665 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
668 self.attention_dropout = attention_dropout
# Flag toggled by MyGPT.record_attention.
669 self.record_attention = False
# Per-head query / key / value projections and the output projection.
671 self.w_q = randw(nb_heads, dim_qk, dim_model)
672 self.w_k = randw(nb_heads, dim_qk, dim_model)
673 self.w_v = randw(nb_heads, dim_v, dim_model)
674 self.w_o = randw(dim_v * nb_heads, dim_model)
# Processes the bracket of bs and returns the cached output as a
# BracketedSequence with the same bracket.
676 def forward(self, bs):
# A non-causal model can only be evaluated on a complete bracket.
680 self.causal or bs.complete()
681 ), "Partial evaluation is only possible for causal models"
# Zero-initialized caches covering the full time axis of x_q.
# NOTE(review): presumably allocated only when bs.init_cache is true; the
# guarding line is not visible.
684 self.cache_k = x_q.new_zeros(
685 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
687 self.cache_v = x_q.new_zeros(
688 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
690 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
# Queries are computed for the bracket only; keys and values of the bracket
# are written into the caches.
692 q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
694 self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
695 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
697 self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
698 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
# Scores of the bracket's queries against all cached keys up to the end of
# the bracket, scaled by 1/sqrt(dim_qk).
702 "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
703 ) / math.sqrt(self.w_q.size(1))
# Boolean mask that is True where the query index is strictly smaller than
# the key index; the slice below selects the part matching the bracket.
707 self.cache_attzero = (
708 torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
709 < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
713 :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
720 if self.record_attention:
723 a = F.dropout(a, self.attention_dropout, self.training)
726 "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
# Output projection, written at the bracket's position in the cache.
729 self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
731 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
734 ##############################
# MyGPT: the full model. It chains an embedding with positional encoding
# (`self.embedding`), a trunk of blocks (`self.trunk`) whose attention
# operator is selected by `attention_layer`, and a linear readout to the
# vocabulary (`self.readout`). All of them operate on BracketedSequence.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers): the __init__ header and most of its parameter list, the
# constructor calls of the attention layers, parts of the trunk blocks and
# several method bodies are not visible. The comments below only describe
# what the visible lines establish.
737 class MyGPT(nn.Module):
747 caterpillar_height=None,
751 attention_layer="kvrec",
755 assert attention_layer in {
760 }, f"Unknown attention operator {attention_layer}."
# For the caterpillar layer, nb_lines is split into caterpillar_height rows
# of caterpillar_length columns; otherwise both are set to -1, which
# disables the padding done in forward.
762 if attention_layer == "caterpillar":
763 assert nb_lines % caterpillar_height == 0
764 self.caterpillar_length = nb_lines // caterpillar_height
765 self.caterpillar_height = caterpillar_height
767 self.caterpillar_length = -1
768 self.caterpillar_height = -1
770 assert dim_model % nb_heads == 0
772 self.embedding = nn.Sequential(
773 CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
774 AddPositionalEncoding(len_max),
# Selection of the attention operator; every variant receives
# dim_v = dim_model // nb_heads and attention_dropout = dropout.
780 if attention_layer == "mha":
784 dim_v=dim_model // nb_heads,
787 attention_dropout=dropout,
789 elif attention_layer == "dumbrec":
793 dim_v=dim_model // nb_heads,
796 attention_dropout=dropout,
798 elif attention_layer == "kvrec":
802 dim_v=dim_model // nb_heads,
805 attention_dropout=dropout,
807 elif attention_layer == "caterpillar":
811 dim_v=dim_model // nb_heads,
813 caterpillar_length=self.caterpillar_length,
814 caterpillar_height=self.caterpillar_height,
815 attention_dropout=dropout,
818 raise ValueError(f"Unknown attention type {attention_layer}.")
# nb_blocks trunk blocks are built from layer norms, the attention operator
# and a two-layer perceptron of width dim_hidden (partially visible).
820 for b in range(nb_blocks):
823 CacheWrapper(nn.LayerNorm((dim_model,))),
828 nn.LayerNorm((dim_model,)),
829 nn.Linear(in_features=dim_model, out_features=dim_hidden),
831 nn.Linear(in_features=dim_hidden, out_features=dim_model),
837 self.trunk = nn.Sequential(*trunk_blocks)
839 self.readout = CacheWrapper(
840 nn.Linear(in_features=dim_model, out_features=vocabulary_size)
# Custom initialization: embedding weights are drawn with std 2e-2; the
# LayerNorm branch body is not visible.
843 with torch.no_grad():
844 for m in self.modules():
845 if isinstance(m, nn.Embedding):
846 m.weight.normal_(mean=0, std=2e-2)
847 elif isinstance(m, nn.LayerNorm):
851 self.reset_inner_loss()
# Runs embedding, trunk (presumably, on a hidden line) and readout on bs.
853 def forward(self, bs):
# Pads one element on the left of the last dimension and crops one on the
# right, i.e. shifts the input one step along that dimension.
854 bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
856 # To make the code simpler in the Caterpillar layer, we pad
857 # here. It's unclear if/how much it hurts computationally by
858 # increasing the sequence length for the other layers
860 if self.caterpillar_length > 0:
# Round the bracket length up to a multiple of caterpillar_length, then pad
# caterpillar_length elements on both sides and shift the bracket start.
862 if bs.nb % self.caterpillar_length > 0:
863 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
865 bs = BracketedSequence(
866 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
867 bs.first + self.caterpillar_length,
872 bs = self.embedding(bs)
874 bs = self.readout(bs)
# Undo the caterpillar padding on the time axis of the output.
876 if self.caterpillar_length > 0:
877 bs = BracketedSequence(
878 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
879 bs.first - self.caterpillar_length,
886 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
887 # 1s where tokens should be generated. The others are kept
# Generates tokens in place: positions where ar_mask is 1 are replaced by
# model predictions, one time step after the other, and the result is
# copied back into input_src.
890 def masked_inplace_autoregression(
894 forbidden_tokens=None,
895 deterministic_synthesis=False,
# Work on copies located on the device of the readout weights.
897 input = input_src.to(self.readout.f.weight.device)
898 ar_mask = ar_mask_src.to(self.readout.f.weight.device)
# Time steps at which at least one sample has a token to generate.
899 to_generate = (ar_mask.sum(0) > 0).nonzero()
900 if to_generate.min() > 0:
902 BracketedSequence(input, 0, to_generate.min(), True)
903 ) # Needed to initialize the model's cache
# One bracket of length 1 per step; the cache is initialized here only if
# the loop starts at s == 0.
904 for s in range(to_generate.min(), to_generate.max() + 1):
905 output = self(BracketedSequence(input, s, 1, s == 0)).x
906 logits = output[:, s]
907 if forbidden_tokens is not None:
908 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
# Either take the most likely token or sample from the logits.
909 if deterministic_synthesis:
910 t_next = logits.argmax(1)
912 dist = torch.distributions.categorical.Categorical(logits=logits)
913 t_next = dist.sample()
# Only positions with ar_mask == 1 are overwritten.
914 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
916 input_src.copy_(input)
# Visits every sub-module that defines reset_inner_loss (call not visible).
918 def reset_inner_loss(self):
919 for m in self.modules():
920 if m is not self and hasattr(m, "reset_inner_loss"):
# Sums the inner losses of every sub-module that defines get_inner_loss.
923 def get_inner_loss(self):
924 l = torch.tensor([0.0], device=self.readout.f.weight.device)
925 for m in self.modules():
926 if m is not self and hasattr(m, "get_inner_loss"):
927 l += m.get_inner_loss()
# Sets the record_attention flag of every QKVAttention sub-module.
930 def record_attention(self, v=True):
931 for m in self.modules():
932 if isinstance(m, QKVAttention):
933 m.record_attention = v
# Visits every QKVAttention sub-module (what is collected is not visible).
935 def retrieve_attention(self):
937 for m in self.modules():
938 if isinstance(m, QKVAttention):
943 ######################################################################
# Script entry point: a consistency check followed by a timing run.
# NOTE(review): this listing skips lines (see the gaps in the embedded line
# numbers); the construction of `m` and `model` is only partially visible.
945 if __name__ == "__main__":
946 print("Basic check.")
953 caterpillar_length=7,
954 caterpillar_height=3,
955 attention_dropout=0.0,
# The same bracket is processed twice (y1, y2), then in two successive
# brackets (y3a with cache initialization, y3b without); the printed
# values are the maximum absolute differences between these results.
959 x = torch.randn(1, 21 + 2 * 7, 4)
960 y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
961 y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
962 y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
963 y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
964 print((y1 - y2).abs().max())
965 print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
# Use the GPU when available.
968 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Random token batch: 6 sequences of length 1024 over 128 symbols.
970 vocabulary_size = 128
971 x = torch.randint(vocabulary_size, (6, 1024))
974 vocabulary_size=vocabulary_size,
990 # import torchvision.models as models
991 # from torch.profiler import profile, record_function, ProfilerActivity
993 # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
994 # with record_function("model_inference"):
# Wall-clock duration of the visible forward call on the whole sequence.
998 start_time = time.perf_counter()
1000 model(BracketedSequence(x))
1001 duration = time.perf_counter() - start_time
1005 # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1006 # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1008 # print("##############################################################")
1009 # y2 = torch.randn_like(y1)
1010 # for s in range(x.size(1)):
1011 # z = model(BracketedSequence(x, s, 1))
1012 # y2[:, s : s + 1] = z.slice()
1014 # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1016 ######################################################################