X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=633ad642c19a3045064ef858c0ee494a7c733425;hb=6e87fe0cb8bd8a0042bbf7b2ede9d8ed0372fb6b;hp=7f0fb9b6fa506a5136ff4e98c8b9f5a4087420ee;hpb=8a32cb4548bb48ef68adb4df9372fe5f7a80b67c;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 7f0fb9b..633ad64 100755 --- a/mygpt.py +++ b/mygpt.py @@ -10,6 +10,8 @@ # with a caching mechanism for keys and values to avoid a O(N^3) cost # for auto-regression. +# This implementation is equipped with RNN layers to replace the MHA + import math, warnings import torch, einops @@ -37,7 +39,7 @@ import ffutils # 1 for the successive tokens. # # Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 +# resetted when init_cache is True class BracketedSequence: @@ -457,87 +459,6 @@ def moving_window(x, dim, win_dim, win_size): ############################## -# This is one order of magnitude more complicated than I expected, not -# elegant, slow, hopefully not buggy - - -def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device): - # starting flash backs - fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long() - fb_start[:, :, -CL:] = 0 - fb_start[:, :, :CL] = 0 - - # Remove series longer than CL - fb_body = fb_start.clone() - fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)] - fb_body = fb_body.cumsum(dim=2) - fb_start = fb_start * (fb_body == 1) - - # Set a origin source time (starting time of the chunck to copy - # here) We set it as the current time minus a multiple of CL to be - # consistent with the "rolling" caterpillar - t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :] - src_time = fb_start * ( - t - - CL - * ( - 1 - + ( - torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1) - ).long() - ) - ) - src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL] - src_time = src_time.cumsum(dim=2) - - src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device) - src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL] - src_head = src_head.cumsum(dim=2) - - # combine - src_delta = fb_start.clone() - src_delta[:, :, CL:] -= fb_start[:, :, :-CL] - src_delta = src_delta.cumsum(dim=2) - src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL] - src_time += src_delta.cumsum(dim=2) - 1 - - return src_time, src_head - - -def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba): - N, H, CH = V.size(0), V.size(1), rec_V.size(1) - - fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device) - - fbt_V = fbt[:, :, :, None].expand_as(rec_V[:, :, t0:t1]) - fbh_V = fbh[:, :, :, None].expand_as(rec_V[:, :, t0:t1]) - t = fbt_V.clamp(min=0) - n = torch.arange(V.size(0), device=V.device)[:, None, None, None].expand_as( - rec_V[:, :, t0:t1] - ) - d = torch.arange(V.size(3), device=V.device)[None, None, None, :].expand_as( - rec_V[:, :, t0:t1] - ) - q = V[:, :, t0:t1][n, fbh_V, t, d] - rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0) - - fbt_K = fbt[:, :, :, None].expand_as(rec_K[:, :, t0:t1]) - fbh_K = fbh[:, :, :, None].expand_as(rec_K[:, :, t0:t1]) - t = fbt_K.clamp(min=0) - n = torch.arange(K.size(0), device=K.device)[:, None, None, None].expand_as( - rec_K[:, :, t0:t1] - ) - d = torch.arange(K.size(3), device=K.device)[None, None, None, :].expand_as( - rec_K[:, :, t0:t1] - ) - q = K[:, :, t0:t1][n, fbh_K, t, d] - rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0) - - # print("SANITY", (fbt_K >=0).float().sum()/fbt_K.numel()) - - -###################################################################### - class Caterpillar(nn.Module): def __init__( @@ -555,13 +476,17 @@ class Caterpillar(nn.Module): warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + def randw(*d, amplitude=None): + if amplitude is None: + amplitude = 1 / math.sqrt(d[-1]) + return nn.Parameter(amplitude * torch.randn(*d)) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout + self.proba_gate_dropout = 0.0 + self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter( torch.full( @@ -574,8 +499,16 @@ class Caterpillar(nn.Module): self.w_Q = randw(nb_heads, dim_qk, dim_model) self.w_O = randw(dim_v * nb_heads, dim_model) - self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) - self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + self.init_K_rec = randw( + caterpillar_height, + caterpillar_length, + dim_qk, + ) + self.init_V_rec = randw( + caterpillar_height, + caterpillar_length, + dim_v, + ) def reset_inner_loss(self): self.acc_attention = 0 @@ -593,78 +526,139 @@ class Caterpillar(nn.Module): N = bs.x.size(0) T = bs.x.size(1) + H = self.w_V.size(0) DV = self.w_V.size(1) DK = self.w_K.size(1) DM = self.w_O.size(1) - CH = self.caterpillar_height - CL = self.caterpillar_length + R = self.caterpillar_height + L = self.caterpillar_length assert ( - t0 >= CL and (t1 - t0) % CL == 0 + t0 >= L and (t1 - t0) % L == 0 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length" # We cache values to deal efficiently with auto-regression if bs.init_cache: - self.rec_V = X.new_zeros(N, CH, T, DV) - self.rec_K = X.new_zeros(N, CH, T, DK) + self.rec_V = X.new_zeros(N, R, T, DV) + self.rec_K = X.new_zeros(N, R, T, DK) # We start the recurrent sequences with optimizable # initial values. No idea if it helps. - self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :] - self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :] + self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :] + self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :] self.cache_Y = X.new_zeros(N, T, DM) + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) + K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + ###################################################################### # Compute the recurrent state # This is the Gating sequence that modulates the storing of - # the new key and value in the CH pairs of the current - # stack. The CH gating values are independent, which means - # that the current K/V could be stored in multiple pairs of the + # the new key and value in the R pairs of the current + # stack. There are R independent gating values, which means + # that the current K/V may be stored in multiple pairs of the # recurrent state, or not at all. G = ( - torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - # That bas a bad idea - # G = F.dropout(G, self.attention_dropout, self.training) + ###################################################################### + # Roll the gating indexes - V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) - K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + warnings.warn("rotating barrel", RuntimeWarning) + + # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}") + + n_barrel = torch.arange(N, device=G.device)[:, None, None, None] + h_barrel = torch.arange(H, device=G.device)[None, :, None, None] + r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + r_barrel = (r_barrel + (t_barrel + t0) // L) % R + + # GG = G.gather(dim=2,index=r_barrel) + G = G[n_barrel, h_barrel, r_barrel, t_barrel] + + # print("SANITY", (GG-G).abs()) + # exit(0) + + ###################################################################### + # The "flashbacks" + + if self.training and self.proba_gate_dropout > 0.0: + # This is a better implementation of "flashbacks". + + # G is NxHxExT where e is the caterpillar's row. + + warnings.warn("gate dropout", RuntimeWarning) + epsilon = 0.5 + + dropout_head = ( + (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) + .expand_as(G) + .float() + ) + + dropout_tail = dropout_head.cumsum(dim=3) - dropout_head + + dropout_active = ( + torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout + ).long() + + dropout_head *= dropout_active + dropout_tail *= dropout_active + + G = ( + G + + dropout_head * (1 - epsilon - G.detach()) + - dropout_tail * G.detach() + ) + + ###################################################################### # We prepare the arguments for the parallel scan + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row + + G = G / G.sum(1, keepdim=True).clamp(min=1) + A = 1 - G.sum(1) - gated_V = torch.einsum("nhet,nhtd->netd", G, V) - gated_K = torch.einsum("nhet,nhtd->netd", G, K) - init_rec_V = self.rec_V[:, :, t0 - CL : t0] - init_rec_K = self.rec_K[:, :, t0 - CL : t0] + # warnings.warn("harmonic recurrence", RuntimeWarning) + # har = torch.arange(t0, t1, device = G.device).float() + 1 + # A = har / (har + 1) + # G = G / har + + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) + + # We start from cached values, which matters in inference - # Here there is a trick: Since the stack at time t is computed - # by updating that at time t-L, the parallel scan operates - # with a period of L. To do so we split the time indexing in - # two axes, the second of size CL, and run the parallel scan - # using the other as the sequence index. + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - A = A.unflatten(2, (-1, CL)) - gated_V = gated_V.unflatten(2, (-1, CL)) - gated_K = gated_K.unflatten(2, (-1, CL)) + ################################################################# + # Associative scan + + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. + + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) - # Put back the sequence index - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) - warnings.warn("flash back", RuntimeWarning) - if self.training: - insert_flash_back(self.rec_V, V, self.rec_K, K, t0, t1, CL, proba=1e-2 / CL) - ###################################################################### # compute the readout @@ -675,14 +669,14 @@ class Caterpillar(nn.Module): # the column in the caterpillar windowed_V = moving_window( - self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) windowed_K = moving_window( - self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) - # We have an attention score for each of the CHxCL values + # We have an attention score for each of the RxL values ar = torch.einsum( "nhtd,nftld->nhtfl", @@ -811,7 +805,6 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, @@ -819,7 +812,12 @@ class MyGPT(nn.Module): ): super().__init__() - assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + assert attention_layer in { + "mha", + "dumbrec", + "kvrec", + "caterpillar", + }, f"Unknown attention operator {attention_layer}." if attention_layer == "caterpillar": assert nb_lines % caterpillar_height == 0 @@ -852,7 +850,7 @@ class MyGPT(nn.Module): return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -861,7 +859,7 @@ class MyGPT(nn.Module): return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -870,7 +868,7 @@ class MyGPT(nn.Module): return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height,