3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
13 # This implementation is equipped with RNN-style layers that can replace the MHA blocks.
20 from torch.nn import functional as F
26 ######################################################################
28 # A BracketedSequence is a BxTx... tensor with a first time step and a nb of time steps to process.
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T without holes.
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True.
45 class BracketedSequence:
46 def __init__(self, x, first=None, nb=None, init_cache=None):
48 assert (first is None and nb is None and init_cache is None) or (
49 first is not None and nb is not None and init_cache is not None
52 self.first = 0 if first is None else first
53 self.nb = x.size(1) if nb is None else nb
54 self.init_cache = True if init_cache is None else init_cache
57 return self.x[:, self.first : self.first + self.nb]
60 return self.first == 0 and self.nb == self.x.size(1)
63 ######################################################################
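# Illustrative sketch (not part of the original code): how a
# prompt-conditioned auto-regressive process typically feeds
# BracketedSequence objects to a module. `model` is a hypothetical module
# mapping a BracketedSequence to a BracketedSequence (e.g. MyGPT below).


def _bracketed_autoregression_sketch(model, x, prompt_len):
    # One bracket of arbitrary length for the prompt, initializing the cache
    y = model(BracketedSequence(x, first=0, nb=prompt_len, init_cache=True))
    # Then brackets of length 1 for the successive tokens
    for t in range(prompt_len, x.size(1)):
        y = model(BracketedSequence(x, first=t, nb=1, init_cache=False))
    return y


##############################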
66 class CacheWrapper(nn.Module):
67 def __init__(self, *f):
69 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
71 def forward(self, bs):
73 y = self.f(bs.slice())
74 self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
75 self.cache_y[:, bs.first : bs.first + bs.nb] = y
77 assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
78 assert bs.first + bs.nb <= self.cache_y.size(1)
79 self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
81 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
84 ##############################
87 class WithResidual(nn.Module):
88 def __init__(self, *f):
90 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
92 def forward(self, bs):
93 return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
96 ##############################
99 class AddPositionalEncoding(nn.Module):
100 def __init__(self, len_max):
102 self.len_max = len_max
104 # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
106 def forward(self, bs):
108 t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
111 j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
116 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
118 self.cache_y = bs.x.new(bs.x.size())
120 self.cache_y[:, bs.first : bs.first + bs.nb] = (
121 bs.slice() + self.pe[bs.first : bs.first + bs.nb]
124 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
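

# Illustrative sketch, not part of the original code: a direct, unvectorized
# reference for the sinusoidal encoding above,
# PE[t, 2i] = sin(t / len_max^(2i/D)), PE[t, 2i+1] = cos(t / len_max^(2i/D)),
# which AddPositionalEncoding.forward builds in one shot with the pi/2 phase
# shift on the odd dimensions.


def _positional_encoding_reference(T, D, len_max):
    pe = torch.zeros(T, D)
    for t in range(T):
        for i in range(0, D, 2):
            pe[t, i] = math.sin(t / (len_max ** (i / D)))
            if i + 1 < D:
                pe[t, i + 1] = math.cos(t / (len_max ** (i / D)))
    return pe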
129 # X is /.../xTxD, A is /.../xT, Y_init is /.../xD
132 def pscan_dim(A, X, Y_init, dim=-2):
134 a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
136 A = A.reshape(a, T, *s[dim + 1 : -1])
137 X = X.reshape(a, T, *s[dim + 1 : -1], -1)
140 Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
142 Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
144 Y = pscan.pscan(A, X, Y_init).reshape(s)
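

# Illustrative sketch, not part of the original code: the linear recurrence
# that pscan.pscan is assumed to compute in parallel, written here as a naive
# sequential loop. With A of shape /.../xT, X and Y of shape /.../xTxD and
# Y_init of shape /.../xD, it is
#
#   Y[..., t, :] = A[..., t, None] * Y[..., t - 1, :] + X[..., t, :]
#
# with Y[..., -1, :] = Y_init.


def _pscan_reference(A, X, Y_init):
    Y = torch.empty_like(X)
    y = Y_init
    for t in range(X.size(-2)):
        y = A[..., t, None] * y + X[..., t, :]
        Y[..., t, :] = y
    return Y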
149 def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
150 with torch.no_grad():
152 for t in range(X.size(dim) - 1, 0, -1):
153 delta = (grad_Y[t] - s_A) / A[t].grad
154 s_A += A[t].grad * delta
156 delta = (grad_Y[t] - s_X) / X[t].grad
157 s_X += X[t].grad * delta
161 def pscan_shape(A, X, Y_init):
163 A = A.reshape(-1, s[-2])
164 X = X.reshape(-1, s[-2], s[-1])
167 Y_init = X.new_zeros(X.size(0), s[-1])
169 Y_init = Y_init.reshape(-1, s[-1])
171 Y = pscan.pscan(A, X, Y_init).reshape(s)
176 def nsum_shape(X, Y_init):
178 X = X.reshape(-1, s[-2], s[-1]) # ntd
180 Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
183 for k in range(X.size(1)):
185 Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
188 return torch.cat(result, dim=1).reshape(s)
191 ##############################
194 class DumbRec(nn.Module):
202 attention_dropout=0.0,
210 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
212 self.nb_lines = nb_lines
213 self.attention_dropout = attention_dropout
215 self.k_star = randw(nb_lines, dim_qk)
217 self.w_qw = randw(nb_heads, dim_qk, dim_model)
218 self.w_qr = randw(nb_heads, dim_qk, dim_model)
219 # self.w_k = randw(nb_heads, dim_qk, dim_model)
220 self.w_v = randw(nb_heads, dim_v, dim_model)
221 self.w_o = randw(dim_v * nb_heads, dim_model)
223 def reset_inner_loss(self):
224 self.acc_attention = 0
227 def get_inner_loss(self):
228 warnings.warn("l2 regularization", RuntimeWarning)
229 return (self.acc_attention / self.acc_nb).pow(2).sum()
230 # return torch.tensor([0], device=self.w_qw.device)
232 def forward(self, bs):
233 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
236 self.rec_v = x_q.new_zeros(
237 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
239 # self.rec_k = x_q.new_zeros(
240 # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
242 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
244 ######################################################################
247 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
249 warnings.warn("rotating key barrel", RuntimeWarning)
250 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
251 t_barrel = torch.arange(t0, t1, device=k_star.device)
252 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
254 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
256 k_star = k_star[l_barrel, t_barrel]
258 ######################################################################
259 # Compute the recurrent state
261 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
263 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
264 # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
270 ) / math.sqrt(self.w_qw.size(1))
272 aw = aw.softmax(dim=2) # nhlt
275 self.acc_attention += aw.sum(dim=(0, 1, 3))
276 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
278 aw = F.dropout(aw, self.attention_dropout, self.training)
280 A = 1 - aw.sum(dim=1) # nlt
282 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
283 # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
289 V0 = self.rec_v[:, :, t0 - 1]
290 # K0 = self.rec_k[:, :, t0 - 1]
292 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
293 # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
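
        # Descriptive note (not part of the original comments): for each
        # memory line l, the recurrence computed above is
        #
        #   rec_v[n, l, t] = A[n, l, t] * rec_v[n, l, t - 1] + sum_h aw[n, h, l, t] * v[n, h, t]
        #
        # with A = 1 - sum_h aw, so a line that receives a lot of attention
        # mass is mostly overwritten with the current value, and a line that
        # receives none keeps its previous content.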
295 ######################################################################
296 # compute the readout
298 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
303 # self.rec_k[:, :, t0:t1],
305 ) / math.sqrt(self.w_qr.size(1))
307 ar = ar.softmax(dim=2) # nhlt
309 ar = F.dropout(ar, self.attention_dropout, self.training)
314 self.rec_v[:, :, t0:t1],
317 self.cache_y[:, t0:t1] = y @ self.w_o
319 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
322 ##############################
325 class KVRec(nn.Module):
333 attention_dropout=0.0,
341 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
343 self.nb_lines = nb_lines
344 self.attention_dropout = attention_dropout
346 self.k_star = randw(nb_lines, dim_qk)
348 self.w_qw = randw(nb_heads, dim_qk, dim_model)
349 self.w_qr = randw(nb_heads, dim_qk, dim_model)
350 self.w_k = randw(nb_heads, dim_qk, dim_model)
351 self.w_v = randw(nb_heads, dim_v, dim_model)
352 self.w_o = randw(dim_v * nb_heads, dim_model)
354 def reset_inner_loss(self):
355 self.acc_attention = 0
358 def get_inner_loss(self):
359 warnings.warn("l2 regularization", RuntimeWarning)
360 return (self.acc_attention / self.acc_nb).pow(2).sum()
361 # return torch.tensor([0], device=self.w_qw.device)
362 # warnings.warn("side regularization", RuntimeWarning)
364 # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
366 # return torch.tensor([0], device=self.w_qw.device)
368 def forward(self, bs):
369 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
372 self.rec_v = x_q.new_zeros(
373 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
375 self.rec_k = x_q.new_zeros(
376 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
378 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
380 ######################################################################
383 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
385 warnings.warn("rotating key barrel", RuntimeWarning)
386 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
387 t_barrel = torch.arange(t0, t1, device=k_star.device)
388 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
390 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
392 k_star = k_star[l_barrel, t_barrel]
394 ######################################################################
395 # Compute the recurrent state
397 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
399 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
400 k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
406 ) / math.sqrt(self.w_qw.size(1))
408 aw = aw.softmax(dim=2) # nhlt
411 # We want all the memory lines to be used similarly
412 self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum across NxHx_xT
413 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
415 aw = F.dropout(aw, self.attention_dropout, self.training)
417 A = 1 - aw.sum(dim=1) # nlt
419 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
420 K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
426 V0 = self.rec_v[:, :, t0 - 1]
427 K0 = self.rec_k[:, :, t0 - 1]
429 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
430 self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
432 ######################################################################
433 # compute the readout
435 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
440 self.rec_k[:, :, t0:t1],
441 ) / math.sqrt(self.w_qr.size(1))
443 ar = ar.softmax(dim=2) # nhlt
445 ar = F.dropout(ar, self.attention_dropout, self.training)
450 self.rec_v[:, :, t0:t1],
453 self.cache_y[:, t0:t1] = y @ self.w_o
455 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
458 ##############################
461 # Returns a tensor with an additional index at rank win_dim, that moves
462 # along the same dimension as dim, over the domain {0...win_size-1}, and
463 # dim is restricted to a domain reduced by win_size-1 values.
466 def moving_window(x, dim, win_dim, win_size):
467 size, stride = x.size(), x.stride()
468 size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
469 size = size[:win_dim] + (win_size,) + size[win_dim:]
470 stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
472 return x.as_strided(size=size, stride=stride)
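

# Illustrative sketch, not part of the original code: a small check of the
# semantics of moving_window. The returned strided view satisfies
# out[..., t, j, ...] == x[..., t + j, ...] and allocates no new memory.


def _moving_window_demo():
    x = torch.arange(2 * 5 * 4).reshape(2, 5, 4)
    w = moving_window(x, dim=1, win_dim=2, win_size=3)
    assert w.shape == (2, 3, 3, 4)
    assert (w[:, 1, 2] == x[:, 3]).all()  # w[:, t, j] == x[:, t + j]
    return w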
475 ##############################
478 class Caterpillar(nn.Module):
487 attention_dropout=0.0,
494 warnings.warn("Caterpillar", RuntimeWarning)
496 def randw(*d, amplitude=None):
497 if amplitude is None:
498 amplitude = 1 / math.sqrt(d[-1])
499 return nn.Parameter(amplitude * torch.randn(*d))
501 self.caterpillar_length = caterpillar_length
502 self.caterpillar_height = caterpillar_height
503 self.attention_dropout = attention_dropout
505 self.gate_dropout_proba = args.gate_dropout_proba
506 self.gate_dropout_sync = args.gate_dropout_sync
507 self.gate_dropout_replace = args.gate_dropout_replace
509 ######################################################################
511 default_bg = -math.log(caterpillar_height - 1)
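        # Assuming the gate G below is a sigmoid of this affine map, this
        # default bias makes every gate start around 1 / caterpillar_height,
        # since sigmoid(-log(H - 1)) = 1 / H.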
512 self.w_G = randw(nb_heads, caterpillar_height, dim_model)
513 self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
515 self.w_K = randw(nb_heads, dim_qk, dim_model)
516 self.w_V = randw(nb_heads, dim_v, dim_model)
517 self.w_Q = randw(nb_heads, dim_qk, dim_model)
518 self.w_O = randw(dim_v * nb_heads, dim_model)
520 self.init_K_rec = randw(
525 self.init_V_rec = randw(
531 # def reset_inner_loss(self):
532 # self.acc_attention = 0
535 # def get_inner_loss(self):
536 # warnings.warn("l2 regularization", RuntimeWarning)
537 # return (self.acc_attention / self.acc_nb).pow(2).sum()
538 # return torch.tensor([0], device=self.w_Q.device)
540 def forward(self, bs):
541 # Dimensions named here to make the source a bit clearer.
543 X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
548 DV = self.w_V.size(1)
549 DK = self.w_K.size(1)
550 DM = self.w_O.size(1)
551 R = self.caterpillar_height
552 L = self.caterpillar_length
555 t0 >= L and (t1 - t0) % L == 0
556 ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
558 # We cache values to deal efficiently with auto-regression
561 self.rec_V = X.new_zeros(N, R, T, DV)
562 self.rec_K = X.new_zeros(N, R, T, DK)
563 # We start the recurrent sequences with optimizable
564 # initial values. No idea if it helps.
565 self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
566 self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
568 self.cache_Y = X.new_zeros(N, T, DM)
570 V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
571 K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
573 ######################################################################
574 # Compute the recurrent state
576 # This is the Gating sequence that modulates the storing of
577 # the new key and value in the R pairs of the current
578 # stack. There are R independent gating values, which means
579 # that the current K/V may be stored in multiple pairs of the
580 # recurrent state, or not at all.
583 torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
586 # Clip the gating to avoid values greater than 1 when several
587 # heads hit the same row
589 G = G / G.sum(1, keepdim=True).clamp(min=1)
591 ######################################################################
593 def recurrence(G, V, K):
594 # We prepare the arguments for the parallel scan
598 gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
599 gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
601 # We start from cached values, which matters in inference
603 init_rec_V = self.rec_V[:, :, t0 - L : t0]
604 init_rec_K = self.rec_K[:, :, t0 - L : t0]
606 # Here there is a trick: Since the stack at position t is
607 # computed by updating that at position t-L, the parallel
608 # scan operates with a period of L. To do so we split the
609 # sequence indexing into two axes, the second of size L, and
610 # run the parallel scan using the first as the sequence index.
612 A = A.unflatten(2, (-1, L))
613 gated_V = gated_V.unflatten(2, (-1, L))
614 gated_K = gated_K.unflatten(2, (-1, L))
616 next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
617 next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
619 return next_V, next_K
621 #################################################################
623 next_V, next_K = recurrence(G, V, K)
625 if self.training and self.gate_dropout_proba > 0.0:
626 # G is NxHxRxT, where R indexes the caterpillar's rows.
628 warnings.warn("gate dropout", RuntimeWarning)
630 if self.gate_dropout_sync:
631 shape_kill = (N, 1, 1)
633 shape_kill = (N, H, R)
635 # Pick a point in each of the NxHxR timelines and set this
636 # entry and the following ones to 1
638 torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
642 # Keep these masks for only some of the NxHxR
644 torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
647 # The coefficients to keep are the complementary ones
650 masked_next_V, masked_next_K = recurrence(G * mask, V, K)
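
            # Descriptive note (not part of the original comments): this is a
            # straight-through style correction. The forward values remain
            # next_V / next_K (detached when gate_dropout_replace is set),
            # while an extra gradient path goes through the masked recurrence,
            # rescaled by 1 / (1 - gate_dropout_proba) to compensate for the
            # dropped gates; with gate_dropout_replace, gradients flow only
            # through that masked path.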
652 if self.gate_dropout_replace:
653 next_V = next_V.detach()
654 next_K = next_K.detach()
656 next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
657 1 - self.gate_dropout_proba
659 next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
660 1 - self.gate_dropout_proba
663 self.rec_V[:, :, t0:t1] = next_V
664 self.rec_K[:, :, t0:t1] = next_K
666 ######################################################################
667 # compute the readout
669 Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
671 # We build tensors NxHxTxRxL where N is the sample index, H
672 # the head, T the time, R the row in the caterpillar, and L
673 # the column in the caterpillar
675 windowed_V = moving_window(
676 self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
679 windowed_K = moving_window(
680 self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
683 # We have an attention score for each of the RxL values
691 # softmax can operate only on one dimension, hence the flatten / view below
694 ar = ar.flatten(3).softmax(dim=3).view(ar.size())
696 ar = F.dropout(ar, self.attention_dropout, self.training)
698 # Compute the output for each head, flatten to concatenate
706 # Compute the final output
708 self.cache_Y[:, t0:t1] = Y @ self.w_O
710 return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
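

# Illustrative sketch, not part of the original code: ignoring the parallel
# scan and the gate dropout, and assuming the usual forget gate
# A = 1 - G.sum over heads (as in the other recurrent layers above), the
# caterpillar state evolves with period L = caterpillar_length as
#
#   rec_V[n, r, t] = A[n, r, t] * rec_V[n, r, t - L] + sum_h G[n, h, r, t] * V[n, h, t]
#
# (and likewise for rec_K). A row r that is strongly gated at time t forgets
# its value from time t - L and stores the current key/value instead. The
# readout then attends, for each query at time t, over the R x L entries
# rec_K[:, :, t - L + 1 : t + 1] of the caterpillar window.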
713 ##############################
716 class QKVAttention(nn.Module):
724 attention_dropout=0.0,
731 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
734 self.attention_dropout = attention_dropout
735 self.record_attention = False
737 self.w_q = randw(nb_heads, dim_qk, dim_model)
738 self.w_k = randw(nb_heads, dim_qk, dim_model)
739 self.w_v = randw(nb_heads, dim_v, dim_model)
740 self.w_o = randw(dim_v * nb_heads, dim_model)
742 def forward(self, bs):
746 self.causal or bs.complete()
747 ), "Partial evaluation is only possible for causal models"
750 self.cache_k = x_q.new_zeros(
751 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
753 self.cache_v = x_q.new_zeros(
754 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
756 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
758 q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
760 self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
761 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
763 self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
764 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
768 "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
769 ) / math.sqrt(self.w_q.size(1))
773 self.cache_attzero = (
774 torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
775 < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
779 :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
786 if self.record_attention:
789 a = F.dropout(a, self.attention_dropout, self.training)
792 "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
795 self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
797 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
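

# Illustrative cross-check, not part of the original code: when the bracket
# covers the whole sequence and attention_dropout is 0, the block above
# computes standard causal multi-head attention, which (with PyTorch >= 2.0)
# can be reproduced with F.scaled_dot_product_attention. `att` is a
# hypothetical QKVAttention instance with causal=True, and x is NxTxD.


def _qkv_attention_reference(att, x):
    q = torch.einsum("ntc,hdc->nhtd", x, att.w_q)
    k = torch.einsum("ntc,hdc->nhtd", x, att.w_k)
    v = torch.einsum("ntc,hdc->nhtd", x, att.w_v)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # NxHxTxDV
    return y.transpose(1, 2).flatten(2) @ att.w_o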
800 ##############################
803 class MyGPT(nn.Module):
813 caterpillar_height=None,
817 attention_layer="kvrec",
823 assert attention_layer in {
828 }, f"Unknown attention operator {attention_layer}."
830 if attention_layer == "caterpillar":
831 assert nb_lines % caterpillar_height == 0
832 self.caterpillar_length = nb_lines // caterpillar_height
833 self.caterpillar_height = caterpillar_height
835 self.caterpillar_length = -1
836 self.caterpillar_height = -1
838 assert dim_model % nb_heads == 0
840 self.embedding = nn.Sequential(
841 CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
842 AddPositionalEncoding(len_max),
848 if attention_layer == "mha":
852 dim_v=dim_model // nb_heads,
855 attention_dropout=dropout,
859 elif attention_layer == "dumbrec":
863 dim_v=dim_model // nb_heads,
866 attention_dropout=dropout,
870 elif attention_layer == "kvrec":
874 dim_v=dim_model // nb_heads,
877 attention_dropout=dropout,
881 elif attention_layer == "caterpillar":
885 dim_v=dim_model // nb_heads,
887 caterpillar_length=self.caterpillar_length,
888 caterpillar_height=self.caterpillar_height,
889 attention_dropout=dropout,
894 raise ValueError(f"Unknown attention type {attention_layer}.")
896 for b in range(nb_blocks):
899 CacheWrapper(nn.LayerNorm((dim_model,))),
904 nn.LayerNorm((dim_model,)),
905 nn.Linear(in_features=dim_model, out_features=dim_hidden),
907 nn.Linear(in_features=dim_hidden, out_features=dim_model),
913 self.trunk = nn.Sequential(*trunk_blocks)
915 self.readout = CacheWrapper(
916 nn.Linear(in_features=dim_model, out_features=vocabulary_size)
919 with torch.no_grad():
920 for m in self.modules():
921 if isinstance(m, nn.Embedding):
922 m.weight.normal_(mean=0, std=2e-2)
923 elif isinstance(m, nn.LayerNorm):
927 self.reset_inner_loss()
929 def forward(self, bs):
930 bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
932 # To make the code simpler in the Caterpillar layer, we pad
933 # here. It's unclear if/how much it hurts computationally by
934 # increasing the sequence length for the other layers
936 if self.caterpillar_length > 0:
938 if bs.nb % self.caterpillar_length > 0:
939 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
941 bs = BracketedSequence(
942 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
943 bs.first + self.caterpillar_length,
948 bs = self.embedding(bs)
950 bs = self.readout(bs)
952 if self.caterpillar_length > 0:
953 bs = BracketedSequence(
954 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
955 bs.first - self.caterpillar_length,
962 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
963 # 1s where tokens should be generated. The others are kept unchanged.
966 def masked_inplace_autoregression(
970 forbidden_tokens=None,
971 deterministic_synthesis=False,
973 input = input_src.to(self.readout.f.weight.device)
974 ar_mask = ar_mask_src.to(self.readout.f.weight.device)
975 to_generate = (ar_mask.sum(0) > 0).nonzero()
976 if to_generate.min() > 0:
978 BracketedSequence(input, 0, to_generate.min(), True)
979 ) # Needed to initialize the model's cache
980 for s in range(to_generate.min(), to_generate.max() + 1):
981 output = self(BracketedSequence(input, s, 1, s == 0)).x
982 logits = output[:, s]
983 if forbidden_tokens is not None:
984 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
985 if deterministic_synthesis:
986 t_next = logits.argmax(1)
988 dist = torch.distributions.categorical.Categorical(logits=logits)
989 t_next = dist.sample()
990 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
992 input_src.copy_(input)
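
    # Illustrative usage sketch, not part of the original code, assuming
    # `model` is a trained MyGPT and `input` is an NxT LongTensor whose first
    # `prompt_len` columns hold the prompt:
    #
    #   ar_mask = torch.zeros_like(input)
    #   ar_mask[:, prompt_len:] = 1  # 1s where tokens must be generated
    #   model.masked_inplace_autoregression(input, ar_mask)
    #   # input now holds the prompt followed by the generated tokens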
994 def reset_inner_loss(self):
995 for m in self.modules():
996 if m is not self and hasattr(m, "reset_inner_loss"):
999 def get_inner_loss(self):
1000 l = torch.tensor([0.0], device=self.readout.f.weight.device)
1001 for m in self.modules():
1002 if m is not self and hasattr(m, "get_inner_loss"):
1003 l += m.get_inner_loss()
1006 def record_attention(self, v=True):
1007 for m in self.modules():
1008 if isinstance(m, QKVAttention):
1009 m.record_attention = v
1011 def retrieve_attention(self):
1013 for m in self.modules():
1014 if isinstance(m, QKVAttention):
1019 ######################################################################
1021 if __name__ == "__main__":
1022 print("Basic check.")
1029 caterpillar_length=7,
1030 caterpillar_height=3,
1031 attention_dropout=0.0,
1034 m.reset_inner_loss()
1035 x = torch.randn(1, 21 + 2 * 7, 4)
1036 y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1037 y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1038 y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1039 y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1040 print((y1 - y2).abs().max())
1041 print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1044 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1046 vocabulary_size = 128
1047 x = torch.randint(vocabulary_size, (6, 1024))
1050 vocabulary_size=vocabulary_size,
1066 # import torchvision.models as models
1067 # from torch.profiler import profile, record_function, ProfilerActivity
1069 # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1070 # with record_function("model_inference"):
1074 start_time = time.perf_counter()
1076 model(BracketedSequence(x))
1077 duration = time.perf_counter() - start_time
1081 # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1082 # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1084 # print("##############################################################")
1085 # y2 = torch.randn_like(y1)
1086 # for s in range(x.size(1)):
1087 # z = model(BracketedSequence(x, s, 1))
1088 # y2[:, s : s + 1] = z.slice()
1090 # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1092 ######################################################################