X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=633ad642c19a3045064ef858c0ee494a7c733425;hb=6e87fe0cb8bd8a0042bbf7b2ede9d8ed0372fb6b;hp=eda8685b8a58653d6739cf4016f412528db378ce;hpb=037adb139441f40078421cd40f6aad1748c2724d;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index eda8685..633ad64 100755 --- a/mygpt.py +++ b/mygpt.py @@ -10,6 +10,8 @@ # with a caching mechanism for keys and values to avoid a O(N^3) cost # for auto-regression. +# This implementation is equipped with RNN layers to replace the MHA + import math, warnings import torch, einops @@ -37,7 +39,7 @@ import ffutils # 1 for the successive tokens. # # Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 +# resetted when init_cache is True class BracketedSequence: @@ -474,15 +476,16 @@ class Caterpillar(nn.Module): warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + def randw(*d, amplitude=None): + if amplitude is None: + amplitude = 1 / math.sqrt(d[-1]) + return nn.Parameter(amplitude * torch.randn(*d)) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - warnings.warn("flash back", RuntimeWarning) - self.proba_flashback = 0.1 + self.proba_gate_dropout = 0.0 self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter( @@ -496,8 +499,16 @@ class Caterpillar(nn.Module): self.w_Q = randw(nb_heads, dim_qk, dim_model) self.w_O = randw(dim_v * nb_heads, dim_model) - self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) - self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + self.init_K_rec = randw( + caterpillar_height, + caterpillar_length, + dim_qk, + ) + self.init_V_rec = randw( + caterpillar_height, + caterpillar_length, + dim_v, + ) def reset_inner_loss(self): self.acc_attention = 0 @@ -519,107 +530,134 @@ class Caterpillar(nn.Module): DV = self.w_V.size(1) DK = self.w_K.size(1) DM = self.w_O.size(1) - CH = self.caterpillar_height - CL = self.caterpillar_length + R = self.caterpillar_height + L = self.caterpillar_length assert ( - t0 >= CL and (t1 - t0) % CL == 0 + t0 >= L and (t1 - t0) % L == 0 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length" # We cache values to deal efficiently with auto-regression if bs.init_cache: - self.rec_V = X.new_zeros(N, CH, T, DV) - self.rec_K = X.new_zeros(N, CH, T, DK) + self.rec_V = X.new_zeros(N, R, T, DV) + self.rec_K = X.new_zeros(N, R, T, DK) # We start the recurrent sequences with optimizable # initial values. No idea if it helps. - self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :] - self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :] + self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :] + self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :] self.cache_Y = X.new_zeros(N, T, DM) + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) + K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + ###################################################################### # Compute the recurrent state # This is the Gating sequence that modulates the storing of - # the new key and value in the CH pairs of the current - # stack. The CH gating values are independent, which means - # that the current K/V could be stored in multiple pairs of the + # the new key and value in the R pairs of the current + # stack. There are R independent gating values, which means + # that the current K/V may be stored in multiple pairs of the # recurrent state, or not at all. G = ( - torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - # That bas a bad idea - # G = F.dropout(G, self.attention_dropout, self.training) - - V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) - K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + ###################################################################### + # Roll the gating indexes - # We prepare the arguments for the parallel scan + warnings.warn("rotating barrel", RuntimeWarning) - A = 1 - G.sum(1) - gated_V = torch.einsum("nhet,nhtd->netd", G, V) - gated_K = torch.einsum("nhet,nhtd->netd", G, K) + # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}") - init_rec_V = self.rec_V[:, :, t0 - CL : t0] - init_rec_K = self.rec_K[:, :, t0 - CL : t0] + n_barrel = torch.arange(N, device=G.device)[:, None, None, None] + h_barrel = torch.arange(H, device=G.device)[None, :, None, None] + r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + r_barrel = (r_barrel + (t_barrel + t0) // L) % R - # Here there is a trick: Since the stack at time t is computed - # by updating that at time t-L, the parallel scan operates - # with a period of L. To do so we split the time indexing in - # two axes, the second of size CL, and run the parallel scan - # using the other as the sequence index. + # GG = G.gather(dim=2,index=r_barrel) + G = G[n_barrel, h_barrel, r_barrel, t_barrel] - A = A.unflatten(2, (-1, CL)) - gated_V = gated_V.unflatten(2, (-1, CL)) - gated_K = gated_K.unflatten(2, (-1, CL)) + # print("SANITY", (GG-G).abs()) + # exit(0) - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + ###################################################################### + # The "flashbacks" - # Put back the sequence index + if self.training and self.proba_gate_dropout > 0.0: + # This is a better implementation of "flashbacks". - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) - self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) + # G is NxHxExT where e is the caterpillar's row. - if self.training and self.proba_flashback > 0.0: - # insert_flash_back(self.rec_V,V,self.rec_K,K,t0,t1,CL,proba=self.proba_flashback / CL,) + warnings.warn("gate dropout", RuntimeWarning) + epsilon = 0.5 - # This piece of code makes the assumption that there is - # nothing informative before t0, otherwise we'd have to - # implement a cache for V and K too. This should not be - # too much of a problem since this is used only during - # train, where full sequence are available + dropout_head = ( + (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) + .expand_as(G) + .float() + ) - n = torch.arange(N, device=X.device)[:, None, None, None] - t = torch.arange(t0, t1, device=X.device)[None, None, :, None] - dv = torch.arange(DV, device=X.device)[None, None, None, :] - dk = torch.arange(DK, device=X.device)[None, None, None, :] + dropout_tail = dropout_head.cumsum(dim=3) - dropout_head - u = ( - torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL - ) * CL + dropout_active = ( + torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout + ).long() - src_time = t - u - t0 - src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device) + dropout_head *= dropout_active + dropout_tail *= dropout_active - mask_V = ( - torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback - ).long() - self.rec_V[:, :, t0:t1] = ( - mask_V * V[n, src_head, src_time, dv] - + (1 - mask_V) * self.rec_V[:, :, t0:t1] + G = ( + G + + dropout_head * (1 - epsilon - G.detach()) + - dropout_tail * G.detach() ) - mask_K = ( - torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback - ).long() - self.rec_K[:, :, t0:t1] = ( - mask_K * K[n, src_head, src_time, dk] - + (1 - mask_K) * self.rec_K[:, :, t0:t1] - ) + ###################################################################### + + # We prepare the arguments for the parallel scan + + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row + + G = G / G.sum(1, keepdim=True).clamp(min=1) + + A = 1 - G.sum(1) + + # warnings.warn("harmonic recurrence", RuntimeWarning) + # har = torch.arange(t0, t1, device = G.device).float() + 1 + # A = har / (har + 1) + # G = G / har + + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) + + # We start from cached values, which matters in inference + + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] + + ################################################################# + # Associative scan + + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. + + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) + + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + + self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) + self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) ###################################################################### # compute the readout @@ -631,14 +669,14 @@ class Caterpillar(nn.Module): # the column in the caterpillar windowed_V = moving_window( - self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) windowed_K = moving_window( - self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) - # We have an attention score for each of the CHxCL values + # We have an attention score for each of the RxL values ar = torch.einsum( "nhtd,nftld->nhtfl", @@ -767,7 +805,6 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, @@ -775,7 +812,12 @@ class MyGPT(nn.Module): ): super().__init__() - assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + assert attention_layer in { + "mha", + "dumbrec", + "kvrec", + "caterpillar", + }, f"Unknown attention operator {attention_layer}." if attention_layer == "caterpillar": assert nb_lines % caterpillar_height == 0 @@ -808,7 +850,7 @@ class MyGPT(nn.Module): return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -817,7 +859,7 @@ class MyGPT(nn.Module): return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -826,7 +868,7 @@ class MyGPT(nn.Module): return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height,