X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=7c8e9f4c894ad332e808d07f008ac4c569046bd1;hb=f0ea1f2375fa3a0be38970a58185cddee97dccef;hp=5ea927e09b7b1835bc1222a6ddf5329868d26da9;hpb=3c5ce93138700c33a055f83ac1a46efb2975e28a;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 5ea927e..7c8e9f4 100755 --- a/mygpt.py +++ b/mygpt.py @@ -10,6 +10,8 @@ # with a caching mechanism for keys and values to avoid a O(N^3) cost # for auto-regression. +# This implementation is equipped with RNN layers to replace the MHA + import math, warnings import torch, einops @@ -37,7 +39,7 @@ import ffutils # 1 for the successive tokens. # # Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 +# resetted when init_cache is True class BracketedSequence: @@ -457,77 +459,6 @@ def moving_window(x, dim, win_dim, win_size): ############################## -# This is one order of magnitude more complicated than I expected, not -# elegant, slow, hopefully not buggy - - -def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device): - # starting flash backs - fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long() - fb_start[:, :, -CL:] = 0 - fb_start[:, :, :CL] = 0 - - # Remove series longer than CL - fb_body = fb_start.clone() - fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)] - fb_body = fb_body.cumsum(dim=2) - fb_start = fb_start * (fb_body == 1) - - # Set a origin source time (starting time of the chunck to copy - # here) We set it as the current time minus a multiple of CL to be - # consistent with the "rolling" caterpillar - t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :] - src_time = fb_start * ( - t - - CL - * ( - 1 - + ( - torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1) - ).long() - ) - ) - src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL] - src_time = src_time.cumsum(dim=2) - - src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device) - src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL] - src_head = src_head.cumsum(dim=2) - - # combine - src_delta = fb_start.clone() - src_delta[:, :, CL:] -= fb_start[:, :, :-CL] - src_delta = src_delta.cumsum(dim=2) - src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL] - src_time += src_delta.cumsum(dim=2) - 1 - - return src_time, src_head - - -def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba): - N, H, CH = V.size(0), V.size(1), rec_V.size(1) - - fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device) - - fbt_V = fbt[:, :, :, None] - fbh_V = fbh[:, :, :, None] - t = fbt_V.clamp(min=0) - n = torch.arange(V.size(0), device=V.device)[:, None, None, None] - d = torch.arange(V.size(3), device=V.device)[None, None, None, :] - q = V[:, :, t0:t1][n, fbh_V, t, d] - rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0) - - fbt_K = fbt[:, :, :, None] - fbh_K = fbh[:, :, :, None] - t = fbt_K.clamp(min=0) - n = torch.arange(K.size(0), device=K.device)[:, None, None, None] - d = torch.arange(K.size(3), device=K.device)[None, None, None, :] - q = K[:, :, t0:t1][n, fbh_K, t, d] - rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0) - - -###################################################################### - class Caterpillar(nn.Module): def __init__( @@ -545,17 +476,18 @@ class Caterpillar(nn.Module): warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + def randw(*d, amplitude=None): + if amplitude is None: + amplitude = 1 / math.sqrt(d[-1]) + return nn.Parameter(amplitude * torch.randn(*d)) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - warnings.warn("flash back", RuntimeWarning) - self.proba_flashback = 0.1 + self.proba_gate_dropout = 0.0 - self.w_G = randw(nb_heads, caterpillar_height, dim_model) + self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5) self.b_G = nn.Parameter( torch.full( (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) @@ -567,8 +499,12 @@ class Caterpillar(nn.Module): self.w_Q = randw(nb_heads, dim_qk, dim_model) self.w_O = randw(dim_v * nb_heads, dim_model) - self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) - self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + self.init_K_rec = randw( + caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5 + ) + self.init_V_rec = randw( + caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5 + ) def reset_inner_loss(self): self.acc_attention = 0 @@ -609,39 +545,86 @@ class Caterpillar(nn.Module): self.cache_Y = X.new_zeros(N, T, DM) + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) + K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + ###################################################################### # Compute the recurrent state # This is the Gating sequence that modulates the storing of # the new key and value in the CH pairs of the current - # stack. The CH gating values are independent, which means - # that the current K/V could be stored in multiple pairs of the + # stack. There are CH independent gating values, which means + # that the current K/V may be stored in multiple pairs of the # recurrent state, or not at all. G = ( - torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - # That bas a bad idea - # G = F.dropout(G, self.attention_dropout, self.training) + ###################################################################### + # The "flashbacks" - V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) - K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + if self.training and self.proba_gate_dropout > 0.0: + # This is a better implementation of "flashbacks". + + # G is NxHxExT where e is the caterpillar's row. + + warnings.warn("gate dropout", RuntimeWarning) + epsilon = 0.5 + + dropout_start = ( + ( + torch.rand(G.size(), device=G.device) + .flatten(2, 3) + .sort(dim=2) + .indices + == 0 + ) + .unflatten(2, (CH, t1 - t0)) + .float() + ) + + dropout_tail = dropout_start.cumsum(dim=3) - dropout_start + + dropout_active = ( + torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout + ).long() + + dropout_start *= dropout_active + dropout_tail *= dropout_active + + G = ( + G + + dropout_start * (1 - epsilon - G.detach()) + - dropout_tail * G.detach() + ) + + ###################################################################### # We prepare the arguments for the parallel scan + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row + + G = G / G.sum(1, keepdim=True).clamp(min=1) + A = 1 - G.sum(1) - gated_V = torch.einsum("nhet,nhtd->netd", G, V) - gated_K = torch.einsum("nhet,nhtd->netd", G, K) + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) + + # We start from cached values, which matters in inference init_rec_V = self.rec_V[:, :, t0 - CL : t0] init_rec_K = self.rec_K[:, :, t0 - CL : t0] - # Here there is a trick: Since the stack at time t is computed - # by updating that at time t-L, the parallel scan operates - # with a period of L. To do so we split the time indexing in - # two axes, the second of size CL, and run the parallel scan - # using the other as the sequence index. + ################################################################# + # Associative scan + + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-CL, the parallel + # scan operates with a period of CL. To do so we split the + # sequence indexing in two axes, the second of size CL, and + # run the parallel scan using the first as the sequence index. A = A.unflatten(2, (-1, CL)) gated_V = gated_V.unflatten(2, (-1, CL)) @@ -650,48 +633,9 @@ class Caterpillar(nn.Module): next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) - # Put back the sequence index - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) - if self.training and self.proba_flashback: - # insert_flash_back(self.rec_V,V,self.rec_K,K,t0,t1,CL,proba=self.proba_flashback / CL,) - - # This piece of code makes the assumption that there is - # nothing informative before t0, otherwise we'd have to - # implement a cache for V and K too. This should not be - # too much of a problem since this is used only during - # train, where full sequence are available - - n = torch.arange(N, device=X.device)[:, None, None, None] - t = torch.arange(t0, t1, device=X.device)[None, None, :, None] - dv = torch.arange(DV, device=X.device)[None, None, None, :] - dk = torch.arange(DK, device=X.device)[None, None, None, :] - - u = ( - torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL - ) * CL - - src_time = t - u - t0 - src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device) - - mask_V = ( - torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback - ).long() - self.rec_V[:, :, t0:t1] = ( - mask_V * V[n, src_head, src_time, dv] - + (1 - mask_V) * self.rec_V[:, :, t0:t1] - ) - - mask_K = ( - torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback - ).long() - self.rec_K[:, :, t0:t1] = ( - mask_K * K[n, src_head, src_time, dk] - + (1 - mask_K) * self.rec_K[:, :, t0:t1] - ) - ###################################################################### # compute the readout @@ -838,7 +782,6 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, @@ -846,7 +789,12 @@ class MyGPT(nn.Module): ): super().__init__() - assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + assert attention_layer in { + "mha", + "dumbrec", + "kvrec", + "caterpillar", + }, f"Unknown attention operator {attention_layer}." if attention_layer == "caterpillar": assert nb_lines % caterpillar_height == 0 @@ -879,7 +827,7 @@ class MyGPT(nn.Module): return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -888,7 +836,7 @@ class MyGPT(nn.Module): return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -897,7 +845,7 @@ class MyGPT(nn.Module): return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height,