3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
18 from torch.nn import functional as F
24 ######################################################################
26 # A BracketedSequence is a BxTx... tensor with a first and a nb time
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
39 # Modules able to process brackets may implement a cache that is
40 # resetted when the input bracket starts at t=0
43 class BracketedSequence:
44 def __init__(self, x, first=None, nb=None, init_cache=None):
46 assert (first is None and nb is None and init_cache is None) or (
47 first is not None and nb is not None and init_cache is not None
50 self.first = 0 if first is None else first
51 self.nb = x.size(1) if nb is None else nb
52 self.init_cache = True if init_cache is None else init_cache
55 return self.x[:, self.first : self.first + self.nb]
58 return self.first == 0 and self.nb == self.x.size(1)
61 ######################################################################
64 class CacheWrapper(nn.Module):
65 def __init__(self, *f):
67 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
69 def forward(self, bs):
71 y = self.f(bs.slice())
72 self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
73 self.cache_y[:, bs.first : bs.first + bs.nb] = y
75 assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
76 assert bs.first + bs.nb <= self.cache_y.size(1)
77 self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
79 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82 ##############################
85 class WithResidual(nn.Module):
86 def __init__(self, *f):
88 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
90 def forward(self, bs):
91 return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94 ##############################
97 class AddPositionalEncoding(nn.Module):
98 def __init__(self, len_max):
100 self.len_max = len_max
102 # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
104 def forward(self, bs):
106 t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109 j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
114 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
116 self.cache_y = bs.x.new(bs.x.size())
118 self.cache_y[:, bs.first : bs.first + bs.nb] = (
119 bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
128 # X is /.../xTxD A is /.../xT Y_init is /.../xD
131 def pscan_dim(A, X, Y_init, dim=-2):
133 a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
135 A = A.reshape(a, T, *s[dim + 1 : -1])
136 X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139 Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
141 Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
143 Y = pscan.pscan(A, X, Y_init).reshape(s)
148 def pscan_shape(A, X, Y_init):
150 A = A.reshape(-1, s[-2])
151 X = X.reshape(-1, s[-2], s[-1])
154 Y_init = X.new_zeros(X.size(0), s[-1])
156 Y_init = Y_init.reshape(-1, s[-1])
158 Y = pscan.pscan(A, X, Y_init).reshape(s)
163 def nsum_shape(X, Y_init):
165 X = X.reshape(-1, s[-2], s[-1]) # ntd
167 Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170 for k in range(X.size(1)):
172 Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175 return torch.cat(result, dim=1).reshape(s)
178 ##############################
181 class DumbRec(nn.Module):
189 attention_dropout=0.0,
195 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
197 self.nb_lines = nb_lines
198 self.attention_dropout = attention_dropout
200 self.k_star = randw(nb_lines, dim_qk)
202 self.w_qw = randw(nb_heads, dim_qk, dim_model)
203 self.w_qr = randw(nb_heads, dim_qk, dim_model)
204 # self.w_k = randw(nb_heads, dim_qk, dim_model)
205 self.w_v = randw(nb_heads, dim_v, dim_model)
206 self.w_o = randw(dim_v * nb_heads, dim_model)
208 def reset_inner_loss(self):
209 self.acc_attention = 0
212 def get_inner_loss(self):
213 warnings.warn("l2 regularization", RuntimeWarning)
214 return (self.acc_attention / self.acc_nb).pow(2).sum()
215 # return torch.tensor([0], device=self.w_qw.device)
217 def forward(self, bs):
218 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
221 self.rec_v = x_q.new_zeros(
222 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
224 # self.rec_k = x_q.new_zeros(
225 # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
227 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
229 ######################################################################
232 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
234 warnings.warn("rotating key barrel", RuntimeWarning)
235 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
236 t_barrel = torch.arange(t0, t1, device=k_star.device)
237 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
239 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
241 k_star = k_star[l_barrel, t_barrel]
243 ######################################################################
244 # Compute the recurrent state
246 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
248 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
249 # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
255 ) / math.sqrt(self.w_qw.size(1))
257 aw = aw.softmax(dim=2) # nhlt
260 self.acc_attention += aw.sum(dim=(0, 1, 3))
261 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
263 aw = F.dropout(aw, self.attention_dropout, self.training)
265 A = 1 - aw.sum(dim=1) # nlt
267 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
268 # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
274 V0 = self.rec_v[:, :, t0 - 1]
275 # K0 = self.rec_k[:, :, t0 - 1]
277 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
278 # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
280 ######################################################################
281 # compute the readout
283 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
288 # self.rec_k[:, :, t0:t1],
290 ) / math.sqrt(self.w_qr.size(1))
292 ar = ar.softmax(dim=2) # nhlt
294 ar = F.dropout(ar, self.attention_dropout, self.training)
299 self.rec_v[:, :, t0:t1],
302 self.cache_y[:, t0:t1] = y @ self.w_o
304 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307 ##############################
310 class KVRec(nn.Module):
318 attention_dropout=0.0,
324 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
326 self.nb_lines = nb_lines
327 self.attention_dropout = attention_dropout
329 self.k_star = randw(nb_lines, dim_qk)
331 self.w_qw = randw(nb_heads, dim_qk, dim_model)
332 self.w_qr = randw(nb_heads, dim_qk, dim_model)
333 self.w_k = randw(nb_heads, dim_qk, dim_model)
334 self.w_v = randw(nb_heads, dim_v, dim_model)
335 self.w_o = randw(dim_v * nb_heads, dim_model)
337 def reset_inner_loss(self):
338 self.acc_attention = 0
341 def get_inner_loss(self):
342 warnings.warn("l2 regularization", RuntimeWarning)
343 return (self.acc_attention / self.acc_nb).pow(2).sum()
344 # return torch.tensor([0], device=self.w_qw.device)
345 # warnings.warn("side regularization", RuntimeWarning)
347 # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
349 # return torch.tensor([0], device=self.w_qw.device)
351 def forward(self, bs):
352 x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
355 self.rec_v = x_q.new_zeros(
356 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
358 self.rec_k = x_q.new_zeros(
359 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
361 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
363 ######################################################################
366 k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
368 warnings.warn("rotating key barrel", RuntimeWarning)
369 k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
370 t_barrel = torch.arange(t0, t1, device=k_star.device)
371 t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
373 torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
375 k_star = k_star[l_barrel, t_barrel]
377 ######################################################################
378 # Compute the recurrent state
380 qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
382 v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
383 k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
389 ) / math.sqrt(self.w_qw.size(1))
391 aw = aw.softmax(dim=2) # nhlt
394 # We want all the memory lines to be used similarly
395 self.acc_attention += aw.sum(dim=(0, 1, 3)) # Sum accross NxHx_xT
396 self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
398 aw = F.dropout(aw, self.attention_dropout, self.training)
400 A = 1 - aw.sum(dim=1) # nlt
402 V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
403 K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
409 V0 = self.rec_v[:, :, t0 - 1]
410 K0 = self.rec_k[:, :, t0 - 1]
412 self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
413 self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
415 ######################################################################
416 # compute the readout
418 qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
423 self.rec_k[:, :, t0:t1],
424 ) / math.sqrt(self.w_qr.size(1))
426 ar = ar.softmax(dim=2) # nhlt
428 ar = F.dropout(ar, self.attention_dropout, self.training)
433 self.rec_v[:, :, t0:t1],
436 self.cache_y[:, t0:t1] = y @ self.w_o
438 return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441 ##############################
444 def moving_window(x, dim, win_dim, win_size):
445 size, stride = x.size(), x.stride()
446 size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
447 size = size[:win_dim] + (win_size,) + size[win_dim:]
448 stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
450 return x.as_strided(size=size, stride=stride)
453 ##############################
456 class Caterpillar(nn.Module):
465 attention_dropout=0.0,
470 warnings.warn("Caterpillar", RuntimeWarning)
473 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
475 self.caterpillar_length = caterpillar_length
476 self.caterpillar_height = caterpillar_height
477 self.attention_dropout = attention_dropout
479 self.w_G = randw(nb_heads, caterpillar_height, dim_model)
480 self.b_G = nn.Parameter(
482 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
486 self.w_K = randw(nb_heads, dim_qk, dim_model)
487 self.w_V = randw(nb_heads, dim_v, dim_model)
488 self.w_Q = randw(nb_heads, dim_qk, dim_model)
489 self.w_O = randw(dim_v * nb_heads, dim_model)
491 self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
492 self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
494 def reset_inner_loss(self):
495 self.acc_attention = 0
498 def get_inner_loss(self):
499 # warnings.warn("l2 regularization", RuntimeWarning)
500 # return (self.acc_attention / self.acc_nb).pow(2).sum()
501 return torch.tensor([0], device=self.w_Q.device)
503 def forward(self, bs):
504 # Dimensions to make the source a bit clearer, that's needed
506 X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
510 DV = self.w_V.size(1)
511 DK = self.w_K.size(1)
512 Dout = self.w_O.size(1)
513 CH = self.caterpillar_height
514 CL = self.caterpillar_length
517 t0 >= CL and (t1 - t0) % CL == 0
518 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
521 self.rec_V = X.new_zeros(N, CH, T, DV)
522 self.rec_K = X.new_zeros(N, CH, T, DK)
523 # We start the recurrent sequences with optimizable
524 # initial values. No idea if it helps.
525 self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
526 self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
528 self.cache_Y = X.new_zeros(N, T, Dout)
530 ######################################################################
531 # Compute the recurrent state
533 # This is the Gating sequence that modulates if they key and
534 # values should be stored in one of the CH pairs of the
535 # current stack. The CH gating values are independent, which
536 # means that the same thing could be stored up to CH times or
540 torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
543 V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
544 K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
546 # We prepare the arguments for the parallel scan
549 gated_V = torch.einsum("nhet,nhtd->netd", G, V)
550 gated_K = torch.einsum("nhet,nhtd->netd", G, K)
552 init_rec_V = self.rec_V[:, :, t0 - CL : t0]
553 init_rec_K = self.rec_K[:, :, t0 - CL : t0]
555 # Here there is a trick: The parallel scan operates with a
556 # period of L, so we split the sequence indexing in two axes,
557 # the second of size CL, and run the parallel scan using the
558 # other alone as the sequence index.
560 A = A.unflatten(2, (-1, CL))
561 gated_V = gated_V.unflatten(2, (-1, CL))
562 gated_K = gated_K.unflatten(2, (-1, CL))
564 next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
565 next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
567 # Put back the sequence index
569 self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
570 self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
572 ######################################################################
573 # compute the readout
575 Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
577 # We build tensors NxHxTxFxL where N is the sample index, H
578 # the head, T the time, F the row in the caterpillar, and L
579 # the column in the caterpillar
581 windowed_V = moving_window(
582 self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
585 windowed_K = moving_window(
586 self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
589 # We have an attention score for each of the CHxCL values
597 # softmax can operate only on one dimension, hence the
600 ar = ar.flatten(3).softmax(dim=3).view(ar.size())
602 ar = F.dropout(ar, self.attention_dropout, self.training)
604 # Compute the output for each head, flatten to concatenate
612 # Compute the final output
614 self.cache_Y[:, t0:t1] = Y @ self.w_O
616 return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
619 ##############################
622 class QKVAttention(nn.Module):
630 attention_dropout=0.0,
635 return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
638 self.attention_dropout = attention_dropout
639 self.record_attention = False
641 self.w_q = randw(nb_heads, dim_qk, dim_model)
642 self.w_k = randw(nb_heads, dim_qk, dim_model)
643 self.w_v = randw(nb_heads, dim_v, dim_model)
644 self.w_o = randw(dim_v * nb_heads, dim_model)
646 def forward(self, bs):
650 self.causal or bs.complete()
651 ), "Partial evaluation is only possible for causal models"
654 self.cache_k = x_q.new_zeros(
655 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
657 self.cache_v = x_q.new_zeros(
658 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
660 self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
662 q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
664 self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
665 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
667 self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
668 "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
672 "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
673 ) / math.sqrt(self.w_q.size(1))
677 self.cache_attzero = (
678 torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
679 < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
683 :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
690 if self.record_attention:
693 a = F.dropout(a, self.attention_dropout, self.training)
696 "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
699 self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
701 return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
704 ##############################
707 class MyGPT(nn.Module):
717 caterpillar_height=None,
722 attention_layer="kvrec",
726 assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
728 if attention_layer == "caterpillar":
729 assert nb_lines % caterpillar_height == 0
730 self.caterpillar_length = nb_lines // caterpillar_height
731 self.caterpillar_height = caterpillar_height
733 self.caterpillar_length = -1
734 self.caterpillar_height = -1
736 assert dim_model % nb_heads == 0
738 self.embedding = nn.Sequential(
739 CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
740 AddPositionalEncoding(len_max),
746 if attention_layer == "mha":
750 dim_v=dim_model // nb_heads,
753 attention_dropout=dropout,
755 elif attention_layer == "dumbrec":
762 attention_dropout=dropout,
764 elif attention_layer == "kvrec":
771 attention_dropout=dropout,
773 elif attention_layer == "caterpillar":
779 caterpillar_length=self.caterpillar_length,
780 caterpillar_height=self.caterpillar_height,
781 attention_dropout=dropout,
784 raise ValueError(f"Unknown attention type {attention_layer}.")
786 for b in range(nb_blocks):
789 CacheWrapper(nn.LayerNorm((dim_model,))),
794 nn.LayerNorm((dim_model,)),
795 nn.Linear(in_features=dim_model, out_features=dim_hidden),
797 nn.Linear(in_features=dim_hidden, out_features=dim_model),
803 self.trunk = nn.Sequential(*trunk_blocks)
805 self.readout = CacheWrapper(
806 nn.Linear(in_features=dim_model, out_features=vocabulary_size)
809 with torch.no_grad():
810 for m in self.modules():
811 if isinstance(m, nn.Embedding):
812 m.weight.normal_(mean=0, std=2e-2)
813 elif isinstance(m, nn.LayerNorm):
817 self.reset_inner_loss()
819 def forward(self, bs):
820 bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
822 # To make the code simpler in the Caterpillar layer, we pad
823 # here. It's unclear if/how much it hurts computationaly by
824 # increasing the sequence length for the other layers
826 if self.caterpillar_length > 0:
828 if bs.nb % self.caterpillar_length > 0:
829 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
831 bs = BracketedSequence(
832 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
833 bs.first + self.caterpillar_length,
838 bs = self.embedding(bs)
840 bs = self.readout(bs)
842 if self.caterpillar_length > 0:
843 bs = BracketedSequence(
844 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
845 bs.first - self.caterpillar_length,
852 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
853 # 1s where tokens should be generated. The others are kept
856 def masked_inplace_autoregression(
860 forbidden_tokens=None,
861 deterministic_synthesis=False,
863 input = input_src.to(self.readout.f.weight.device)
864 ar_mask = ar_mask_src.to(self.readout.f.weight.device)
865 to_generate = (ar_mask.sum(0) > 0).nonzero()
866 if to_generate.min() > 0:
868 BracketedSequence(input, 0, to_generate.min(), True)
869 ) # Needed to initialize the model's cache
870 for s in range(to_generate.min(), to_generate.max() + 1):
871 output = self(BracketedSequence(input, s, 1, s == 0)).x
872 logits = output[:, s]
873 if forbidden_tokens is not None:
874 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
875 if deterministic_synthesis:
876 t_next = logits.argmax(1)
878 dist = torch.distributions.categorical.Categorical(logits=logits)
879 t_next = dist.sample()
880 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
882 input_src.copy_(input)
884 def reset_inner_loss(self):
885 for m in self.modules():
886 if m is not self and hasattr(m, "reset_inner_loss"):
889 def get_inner_loss(self):
890 l = torch.tensor([0.0], device=self.readout.f.weight.device)
891 for m in self.modules():
892 if m is not self and hasattr(m, "get_inner_loss"):
893 l += m.get_inner_loss()
896 def record_attention(self, v=True):
897 for m in self.modules():
898 if isinstance(m, QKVAttention):
899 m.record_attention = v
901 def retrieve_attention(self):
903 for m in self.modules():
904 if isinstance(m, QKVAttention):
909 ######################################################################
911 if __name__ == "__main__":
912 print("Basic check.")
919 caterpillar_length=7,
920 caterpillar_height=3,
921 attention_dropout=0.0,
925 x = torch.randn(1, 21 + 2 * 7, 4)
926 y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
927 y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
928 y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
929 y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
930 print((y1 - y2).abs().max())
931 print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
934 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
936 vocabulary_size = 128
937 x = torch.randint(vocabulary_size, (6, 1024))
940 vocabulary_size=vocabulary_size,
956 # import torchvision.models as models
957 # from torch.profiler import profile, record_function, ProfilerActivity
959 # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
960 # with record_function("model_inference"):
964 start_time = time.perf_counter()
966 model(BracketedSequence(x))
967 duration = time.perf_counter() - start_time
971 # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
972 # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
974 # print("##############################################################")
975 # y2 = torch.randn_like(y1)
976 # for s in range(x.size(1)):
977 # z = model(BracketedSequence(x, s, 1))
978 # y2[:, s : s + 1] = z.slice()
980 # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
982 ######################################################################