# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid a O(N^3) cost
# for auto-regression.

# This implementation is equipped with RNN layers to replace the MHA.

import math, time, warnings

import torch

from torch import nn
from torch.nn import functional as F

import pscan  # parallel-scan module (pscan.py in this repo)

######################################################################
# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.

# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.

# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)
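
# An illustrative example, with arbitrary shapes and values: a prompt
# of length 4 processed as a single bracket, then two single-token
# brackets.
#
# x = torch.randn(2, 6, 8)
# bs = BracketedSequence(x, first=0, nb=4, init_cache=True)
# assert bs.slice().shape == (2, 4, 8) and not bs.complete()
# bs = BracketedSequence(x, first=4, nb=1, init_cache=False)
# bs = BracketedSequence(x, first=5, nb=1, init_cache=False)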

######################################################################


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
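
# An illustrative example, with arbitrary shapes: wrapping a
# per-time-step module so that successive brackets fill a single
# persistent cache.
#
# m = CacheWrapper(nn.Linear(8, 8))
# x = torch.randn(2, 6, 8)
# y = m(BracketedSequence(x, 0, 4, True))   # init_cache=True fills t in [0, 4)
# y = m(BracketedSequence(x, 4, 2, False))  # completes t in [4, 6)
# assert y.x.shape == (2, 6, 8)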

##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
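
# Sanity note: with k = j % 2, the pi/2 phase shift turns the sine
# into a cosine on odd coordinates, e.g.
#
# assert abs(math.sin(0.5 + math.pi / 2) - math.cos(0.5)) < 1e-12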

######################################################################

# X is /.../xTxD   A is /.../xT   Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y
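
# For reference: pscan.pscan computes the linear recurrence
# Y_t = A_t * Y_{t-1} + X_t in parallel over t. A naive sequential
# equivalent, for the flattened (N, T) / (N, T, D) form used in
# pscan_shape below:
#
# def pscan_reference(A, X, Y_init):
#     Y, out = Y_init, []
#     for t in range(X.size(1)):
#         Y = A[:, t, None] * Y + X[:, t]
#         out.append(Y)
#     return torch.stack(out, dim=1)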

def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y

def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)

##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            #     x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
            # self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
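
# In other words, with A = 1 - sum_h aw, each of the nb_lines memory
# lines follows the linear recurrence
#
#   rec_v[l, t] = A[l, t] * rec_v[l, t - 1] + sum_h aw[h, l, t] * v[h, t]
#
# which pscan_shape evaluates in parallel over t.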

##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
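
# The difference with DumbRec: the keys are themselves stored
# recurrently (rec_k), so the readout attends to time-dependent keys
# instead of the static k_star.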

##############################


# Returns a tensor with an additional index at rank win_dim, that
# moves along the same dimension as dim, on a domain {0...win_size-1},
# and dim is restricted to a domain reduced by win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
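
# An illustrative example:
#
# x = torch.arange(5)[None, :]                        # 1x5
# w = moving_window(x, dim=1, win_dim=2, win_size=3)  # 1x3x3
# # w[0] is [[0, 1, 2], [1, 2, 3], [2, 3, 4]]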

##############################


class Calibrator:
    def __init__(self, w=None, b=None):
        self.w = w
        self.b = b
        self.s, self.s_sq, self.n = 0, 0, 0
        self.mean, self.std = 0, 0

    def update(self, X):
        X = X.detach()
        self.s += X.sum(dim=0)
        self.s_sq += X.pow(2).sum(dim=0)
        self.n += X.size(0)

    def moments(self):
        mean = self.s / self.n
        std = (self.s_sq / self.n - mean * mean).sqrt()
        return mean, std

    # NOTE: "normalize" is a guessed name; the body recenters b and
    # rescales w with the accumulated moments, then resets the counters
    def normalize(self):
        mean, std = self.moments()
        if self.b is not None:
            self.b.sub_(mean)
        if self.w is not None:
            self.w.div_(std)
        result = mean - self.mean, std - self.std
        self.mean, self.std = mean, std
        self.s, self.s_sq, self.n = 0, 0, 0
        return result
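
# An illustrative example, with arbitrary values:
#
# c = Calibrator()
# c.update(torch.randn(10000, 8) * 3.0 + 1.0)
# mean, std = c.moments()  # close to 1.0 and 3.0 respectively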


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################

        x = kwargs.get("gate_dropout")
        if x is None:
            self.proba_gate_dropout = 0.0
        else:
            self.proba_gate_dropout = float(x)

        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")

        x = kwargs.get("default_bg")
        if x is None:
            default_bg = -math.log(caterpillar_height - 1)
        else:
            default_bg = float(x)

        logger(f"default_bg {default_bg}")

        ######################################################################

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

        self.calibrator_G = Calibrator()
        self.calibrator_rec_V = Calibrator()
        self.calibrator_rec_K = Calibrator()

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions, to make the source a bit clearer

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        self.calibrator_G.update(G.reshape(-1, G.size(-1)))

        # warnings.warn("softmax gating", RuntimeWarning)

        # G = (
        #     torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        # ).softmax(dim=2)

        ######################################################################
        # Gate dropout

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT where R is the caterpillar's row.

            warnings.warn("gate dropout", RuntimeWarning)

            kill = (
                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
            ).float()

            alpha = G / (1 - self.proba_gate_dropout)

            G = alpha * (1 - kill)

        ######################################################################
        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################
        # Roll the gating indexes

        # warnings.warn("rotating barrel", RuntimeWarning)

        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
        # G = G.gather(dim=2, index=r_barrel.expand_as(G))

        ######################################################################
        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)

        # warnings.warn("harmonic recurrence", RuntimeWarning)
        # har = torch.arange(t0, t1, device = G.device).float() + 1
        # A = har / (har + 1)
        # G = G / har

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        next_V = next_V.flatten(2, 3)
        next_K = next_K.flatten(2, 3)

        self.calibrator_rec_V.update(
            next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
        )
        self.calibrator_rec_K.update(
            next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
        )

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
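
# In summary: with G of size NxHxRxT and A = 1 - G.sum(1), every row r
# of the recurrent stack follows, with a period of L,
#
#   rec_V[r, t] = A[r, t] * rec_V[r, t - L] + sum_h G[h, r, t] * V[h, t]
#
# and the readout attends, for every t, to the RxL most recent entries.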

##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
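
# An illustrative example: the cached causal mask is True where the
# key index s exceeds the query index t, i.e. the upper triangle that
# gets filled with -inf.
#
# t = torch.arange(3)
# t[:, None] < t[None, :]
# # tensor([[False,  True,  True],
# #         [False, False,  True],
# #         [False, False, False]])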

##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
        logger=print,
        **kwargs,
    ):
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                    logger=logger,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                    logger=logger,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if / how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)
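
    # An illustrative use, with arbitrary shapes: generate the last
    # three tokens of each sequence in place.
    #
    # input = torch.randint(vocabulary_size, (2, 10))
    # ar_mask = torch.zeros_like(input)
    # ar_mask[:, -3:] = 1
    # model.masked_inplace_autoregression(input, ar_mask)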

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


######################################################################

if __name__ == "__main__":
    print("Basic check.")

    # The dimensions below stand in for elided originals; dim_model is
    # fixed to 4 by the shape of x, the others are arbitrary
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # The hyper-parameters below stand in for elided originals
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    model.to(device)
    x = x.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    #     with record_function("model_inference"):

    start_time = time.perf_counter()
    model(BracketedSequence(x))
    duration = time.perf_counter() - start_time

    print(duration)

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    #     z = model(BracketedSequence(x, s, 1))
    #     y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################