diff --git a/mygpt.py b/mygpt.py
index b885e21..492a9bb 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -10,6 +10,8 @@
 # with a caching mechanism for keys and values to avoid a O(N^3) cost
 # for auto-regression.
 
+# This implementation is equipped with RNN layers that can replace the MHA.
+
 import math, warnings
 
 import torch, einops
@@ -124,7 +126,6 @@ class AddPositionalEncoding(nn.Module):
 
 import pscan
 
-
 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
 
@@ -145,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2):
     return Y
 
 
+def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
+    with torch.no_grad():
+        s_A, s_X = 0, 0
+        for t in range(X.size(dim) - 1, 0, -1):
+            delta = (grad_Y[t] - s_A) / A[t].grad
+            s_A += A[t].grad * delta
+            A[t].grad = delta
+            delta = (grad_Y[t] - s_X) / X[t].grad
+            s_X += X[t].grad * delta
+            X[t].grad = delta
+
+
 def pscan_shape(A, X, Y_init):
     s = X.size()
     A = A.reshape(-1, s[-2])
@@ -188,6 +201,8 @@ class DumbRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
@@ -317,6 +332,8 @@ class KVRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
@@ -469,35 +486,61 @@ class Caterpillar(nn.Module):
         caterpillar_height,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
         warnings.warn("Caterpillar", RuntimeWarning)
 
-        def randw(*d):
-            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+        def randw(*d, amplitude=None):
+            if amplitude is None:
+                amplitude = 1 / math.sqrt(d[-1])
+            return nn.Parameter(amplitude * torch.randn(*d))
 
         self.caterpillar_length = caterpillar_length
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        self.proba_flashback = 0.0
-        self.proba_gate_dropout = 0.0
+        ######################################################################
+        # sup_args
+
+        x = kwargs.get("gate_dropout")
+        if x is None:
+            self.proba_gate_dropout = 0.0
+        else:
+            self.proba_gate_dropout = float(x)
+
+        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")
+
+        x = kwargs.get("default_bg")
+        if x is None:
+            default_bg = -math.log(caterpillar_height - 1)
+        else:
+            default_bg = float(x)
+
+        logger(f"default_bg {default_bg}")
+
+        ######################################################################
 
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(
-            torch.full(
-                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
-            )
-        )
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
         self.w_V = randw(nb_heads, dim_v, dim_model)
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
-        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
-        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+        self.init_K_rec = randw(
+            caterpillar_height,
+            caterpillar_length,
+            dim_qk,
+        )
+        self.init_V_rec = randw(
+            caterpillar_height,
+            caterpillar_length,
+            dim_v,
+        )
 
     def reset_inner_loss(self):
         self.acc_attention = 0
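The default gating bias default_bg = -log(caterpillar_height - 1) introduced above is chosen so that, at initialization, each gate opens with probability about 1/R: sigmoid(-log(R - 1)) = 1/(1 + (R - 1)) = 1/R. A minimal sketch checking this, with a hypothetical caterpillar_height of 16:

    import math, torch

    R = 16  # hypothetical caterpillar_height
    default_bg = -math.log(R - 1)

    # sigmoid(-log(R - 1)) = 1 / (1 + (R - 1)) = 1/R, so with a zero input
    # each of the R rows is initially written to with the same probability.
    print(torch.tensor(default_bg).sigmoid().item())  # ~0.0625 = 1/16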
@@ -519,113 +562,108 @@ class Caterpillar(nn.Module):
         DV = self.w_V.size(1)
         DK = self.w_K.size(1)
         DM = self.w_O.size(1)
-        CH = self.caterpillar_height
-        CL = self.caterpillar_length
+        R = self.caterpillar_height
+        L = self.caterpillar_length
 
         assert (
-            t0 >= CL and (t1 - t0) % CL == 0
+            t0 >= L and (t1 - t0) % L == 0
         ), f"bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
 
         # We cache values to deal efficiently with auto-regression
 
         if bs.init_cache:
-            self.rec_V = X.new_zeros(N, CH, T, DV)
-            self.rec_K = X.new_zeros(N, CH, T, DK)
+            self.rec_V = X.new_zeros(N, R, T, DV)
+            self.rec_K = X.new_zeros(N, R, T, DK)
             # We start the recurrent sequences with optimizable
             # initial values. No idea if it helps.
-            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
-            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
+            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
+            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
 
             self.cache_Y = X.new_zeros(N, T, DM)
 
+        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
         ######################################################################
         # Compute the recurrent state
 
         # This is the Gating sequence that modulates the storing of
-        # the new key and value in the CH pairs of the current
-        # stack. The CH gating values are independent, which means
-        # that the current K/V could be stored in multiple pairs of the
+        # the new key and value in the R pairs of the current
+        # stack. There are R independent gating values, which means
+        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.
 
         G = (
-            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        if self.training and self.proba_gate_dropout > 0.0:
-            warnings.warn("gate droupout", RuntimeWarning)
-            epsilon = 0.5
+        # Clip the gating so that the total write into a row cannot
+        # exceed 1 when several heads hit the same row
 
-        # That was a bad idea
-        # G = F.dropout(G, self.attention_dropout, self.training)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
-        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+        ######################################################################
 
-        # We prepare the arguments for the parallel scan
+        def recurrence(G, V, K):
+            # We prepare the arguments for the parallel scan
 
-        # Clip the gating
-        warnings.warn("gating clipping", RuntimeWarning)
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+            A = 1 - G.sum(1)
 
-        A = 1 - G.sum(1)
-        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
-        gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
-        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
-        init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+            # We start from cached values, which matters during inference
 
-        # Here there is a trick: Since the stack at time t is computed
-        # by updating that at time t-L, the parallel scan operates
-        # with a period of L. To do so we split the time indexing in
-        # two axes, the second of size CL, and run the parallel scan
-        # using the other as the sequence index.
+            init_rec_V = self.rec_V[:, :, t0 - L : t0]
+            init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-        A = A.unflatten(2, (-1, CL))
-        gated_V = gated_V.unflatten(2, (-1, CL))
-        gated_K = gated_K.unflatten(2, (-1, CL))
+            # Associative scan
 
-        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
-        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+            # Here there is a trick: since the stack at position t is
+            # computed by updating that at position t-L, the parallel
+            # scan operates with a period of L. To do so we split the
+            # sequence indexing into two axes, the second of size L, and
+            # run the parallel scan using the first as the sequence index.
 
-        # Put back the sequence index
+            A = A.unflatten(2, (-1, L))
+            gated_V = gated_V.unflatten(2, (-1, L))
+            gated_K = gated_K.unflatten(2, (-1, L))
 
-        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
-        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
+            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
 
-        if self.training and self.proba_flashback > 0.0:
-            warnings.warn("flash back", RuntimeWarning)
-            # This piece of code makes the assumption that there is
-            # nothing informative before t0, otherwise we'd have to
-            # implement a cache for V and K too. This should not be
-            # too much of a problem since this is used only during
-            # train, where full sequence are available
+            next_V = next_V.flatten(2, 3)
+            next_K = next_K.flatten(2, 3)
 
-            n = torch.arange(N, device=X.device)[:, None, None, None]
-            t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
-            dv = torch.arange(DV, device=X.device)[None, None, None, :]
-            dk = torch.arange(DK, device=X.device)[None, None, None, :]
+            return next_V, next_K
 
-            u = (
-                torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
-            ) * CL
+        #################################################################
 
-            src_time = t - u - t0
-            src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
+        next_V, next_K = recurrence(G, V, K)
 
-            mask = (
-                torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
-            ).long()
+        if self.training and self.proba_gate_dropout > 0.0:
+            # G is NxHxRxT, where the R axis indexes the caterpillar's rows.
 
-            self.rec_V[:, :, t0:t1] = (
-                mask * V[n, src_head, src_time, dv]
-                + (1 - mask) * self.rec_V[:, :, t0:t1]
-            )
+            warnings.warn("gate dropout", RuntimeWarning)
+
+            kill = (
+                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            ).float()
+
+            mask = 1 - kill
 
-            self.rec_K[:, :, t0:t1] = (
-                mask * K[n, src_head, src_time, dk]
-                + (1 - mask) * self.rec_K[:, :, t0:t1]
+            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+                1 - self.proba_gate_dropout
             )
 
+        self.rec_V[:, :, t0:t1] = next_V
+        self.rec_K[:, :, t0:t1] = next_K
+
         ######################################################################
 
         # compute the readout
 
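The gate-dropout branch above is a straight-through-style estimator: numerically, next_V and next_K keep the values of the unmasked recurrence, while the gradient flows only through the masked recurrence, rescaled by 1/(1 - proba_gate_dropout) to keep its expectation unchanged. A toy sketch of the pattern, where cumsum is a hypothetical stand-in for recurrence():

    import torch

    p = 0.25
    g = torch.randn(8, requires_grad=True)

    full = g.sigmoid().cumsum(0)                   # recurrence(G, V, K)
    kill = (torch.rand(8) <= p).float()
    masked = (g.sigmoid() * (1 - kill)).cumsum(0)  # recurrence(G * mask, V, K)

    # Forward value is that of the unmasked path; the gradient comes
    # from the masked path, rescaled by 1 / (1 - p).
    out = full.detach() + (masked - masked.detach()) / (1 - p)

    print(torch.allclose(out, full))  # True: the forward pass is unchanged
    out.sum().backward()              # g.grad comes from the masked path only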
@@ -636,14 +674,14 @@ class Caterpillar(nn.Module):
         # the column in the caterpillar
 
         windowed_V = moving_window(
-            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
         )
 
         windowed_K = moving_window(
-            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
         )
 
-        # We have an attention score for each of the CHxCL values
+        # We have an attention score for each of the RxL values
 
         ar = torch.einsum(
             "nhtd,nftld->nhtfl",
@@ -685,6 +723,8 @@ class QKVAttention(nn.Module):
         nb_heads=1,
         causal=False,
         attention_dropout=0.0,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
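moving_window is defined elsewhere in mygpt.py and does not appear in this diff. Given how it is called here, it has to expose, at each time step, the L previous states of every row as a new axis; a plausible minimal sketch using unfold (an assumption, not the file's actual implementation):

    import torch

    def moving_window(x, dim, win_dim, win_size):
        # Stride-1 windows of length win_size along `dim`; unfold appends
        # the window axis last, movedim brings it to position `win_dim`.
        return x.unfold(dim, win_size, 1).movedim(-1, win_dim)

    # (N, R, (t1 - t0) + L - 1, D) -> (N, R, t1 - t0, L, D), the "nftld"
    # layout consumed by the einsum that computes the attention scores.
    x = torch.randn(2, 16, 39, 64)
    print(moving_window(x, dim=2, win_dim=3, win_size=8).shape)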
@@ -772,11 +812,12 @@ class MyGPT(nn.Module):
         nb_blocks,
         nb_lines=None,
         caterpillar_height=None,
-        dim_rec_v=-1,
         causal=False,
         dropout=0.0,
         len_max=1e5,
         attention_layer="kvrec",
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
@@ -813,34 +854,42 @@ class MyGPT(nn.Module):
                 nb_heads=nb_heads,
                 causal=causal,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "dumbrec":
             return DumbRec(
                 dim_model=dim_model,
                 dim_qk=dim_keys,
-                dim_v=dim_rec_v,
+                dim_v=dim_model // nb_heads,
                 nb_heads=nb_heads,
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "kvrec":
             return KVRec(
                 dim_model=dim_model,
                 dim_qk=dim_keys,
-                dim_v=dim_rec_v,
+                dim_v=dim_model // nb_heads,
                 nb_heads=nb_heads,
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "caterpillar":
             return Caterpillar(
                 dim_model=dim_model,
                 dim_qk=dim_keys,
-                dim_v=dim_rec_v,
+                dim_v=dim_model // nb_heads,
                 nb_heads=nb_heads,
                 caterpillar_length=self.caterpillar_length,
                 caterpillar_height=self.caterpillar_height,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         else:
             raise ValueError(f"Unknown attention type {attention_layer}.")
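Since the factory above forwards **kwargs to every attention layer, the new sup_args reach Caterpillar through the MyGPT constructor. A hedged construction sketch with hypothetical dimension values, showing the two keys Caterpillar actually looks up:

    attention = Caterpillar(
        dim_model=512,
        dim_qk=64,
        dim_v=512 // 8,     # dim_model // nb_heads, as in the factory above
        nb_heads=8,
        caterpillar_length=32,
        caterpillar_height=16,
        attention_dropout=0.0,
        gate_dropout=0.25,  # read via kwargs.get("gate_dropout")
        default_bg=-3.0,    # overrides -math.log(caterpillar_height - 1)
    )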