X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=a27b99e8dd47eb14696257fb1d814c8e33dd49cb;hb=64dc96ddfa84511ba07d1929481e93e864735409;hp=d8fd227f63c39a70dded3c55f3c230c3a9d58862;hpb=3e4af6d54fb3d7bd6794035cb79e30ecdcadeb6f;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index d8fd227..a27b99e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -126,7 +126,6 @@ class AddPositionalEncoding(nn.Module): import pscan - # X is /.../xTxD A is /.../xT Y_init is /.../xD @@ -147,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2): return Y +def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2): + with torch.no_grad(): + s_A, s_X = 0, 0 + for t in range(X.size(dim) - 1, 0, -1): + delta = (grad_Y[t] - s_A) / A[t].grad + s_A += A[t].grad * delta + A[t].grad = delta + delta = (grad_Y[t] - s_X) / X[t].grad + s_X += X[t].grad * delta + X[t].grad = delta + + def pscan_shape(A, X, Y_init): s = X.size() A = A.reshape(-1, s[-2]) @@ -190,6 +201,8 @@ class DumbRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -319,6 +332,8 @@ class KVRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -471,34 +486,61 @@ class Caterpillar(nn.Module): caterpillar_height, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + def randw(*d, amplitude=None): + if amplitude is None: + amplitude = 1 / math.sqrt(d[-1]) + return nn.Parameter(amplitude * torch.randn(*d)) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.proba_gate_dropout = 0.0 + ###################################################################### + # sup_args + + x = kwargs.get("gate_dropout") + if x is None: + self.proba_gate_dropout = 0.0 + else: + self.proba_gate_dropout = float(x) + + logger(f"self.proba_gate_dropout {self.proba_gate_dropout}") + + x = kwargs.get("default_bg") + if x is None: + default_bg = -math.log(caterpillar_height - 1) + else: + default_bg = float(x) + + logger(f"default_bg {default_bg}") + + ###################################################################### self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter( - torch.full( - (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) - ) - ) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) self.w_Q = randw(nb_heads, dim_qk, dim_model) self.w_O = randw(dim_v * nb_heads, dim_model) - self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) - self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + self.init_K_rec = randw( + caterpillar_height, + caterpillar_length, + dim_qk, + ) + self.init_V_rec = randw( + caterpillar_height, + caterpillar_length, + dim_v, + ) def reset_inner_loss(self): self.acc_attention = 0 @@ -520,22 +562,22 @@ class Caterpillar(nn.Module): DV = self.w_V.size(1) DK = self.w_K.size(1) DM = self.w_O.size(1) - CH = self.caterpillar_height - CL = self.caterpillar_length + R = self.caterpillar_height + L = self.caterpillar_length assert ( - t0 >= CL and (t1 - t0) % CL == 0 + t0 >= L and (t1 - t0) % L == 0 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length" # We cache values to deal efficiently with auto-regression if bs.init_cache: - self.rec_V = X.new_zeros(N, CH, T, DV) - self.rec_K = X.new_zeros(N, CH, T, DK) + self.rec_V = X.new_zeros(N, R, T, DV) + self.rec_K = X.new_zeros(N, R, T, DK) # We start the recurrent sequences with optimizable # initial values. No idea if it helps. - self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :] - self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :] + self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :] + self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :] self.cache_Y = X.new_zeros(N, T, DM) @@ -546,55 +588,83 @@ class Caterpillar(nn.Module): # Compute the recurrent state # This is the Gating sequence that modulates the storing of - # the new key and value in the CH pairs of the current - # stack. There are CH independent gating values, which means + # the new key and value in the R pairs of the current + # stack. There are R independent gating values, which means # that the current K/V may be stored in multiple pairs of the # recurrent state, or not at all. G = ( - torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row + # warnings.warn("softmax gating", RuntimeWarning) + + # G = ( + # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] + # ).softmax(dim=2) - G = G / G.sum(1, keepdim=True).clamp(min=1) + ###################################################################### + # The "flashbacks" - # We prepare the arguments for the parallel scan + if self.training and self.proba_gate_dropout > 0.0: + # This is a better implementation of "flashbacks". - A = 1 - G.sum(1) - gated_V = torch.einsum("nhet,nhtd->netd", G, V) - gated_K = torch.einsum("nhet,nhtd->netd", G, K) + # G is NxHxExT where e is the caterpillar's row. - # We start from cached values, which matters in inference + warnings.warn("gate dropout", RuntimeWarning) - init_rec_V = self.rec_V[:, :, t0 - CL : t0] - init_rec_K = self.rec_K[:, :, t0 - CL : t0] + kill = ( + torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout + ).float() - ###################################################################### + alpha = G / (1 - self.proba_gate_dropout) - if self.training and self.proba_gate_dropout > 0.0: - warnings.warn("gate dropout", RuntimeWarning) - epsilon = 0.5 + G = alpha * (1 - kill) - ################################################################# - # Associative scan + def recurrence(G, V, K): + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row + + G = G / G.sum(1, keepdim=True).clamp(min=1) + + # We prepare the arguments for the parallel scan + + A = 1 - G.sum(1) - # Here there is a trick: Since the stack at position t is - # computed by updating that at position t-CL, the parallel - # scan operates with a period of CL. To do so we split the - # sequence indexing in two axes, the second of size CL, and - # run the parallel scan using the first as the sequence index. + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - A = A.unflatten(2, (-1, CL)) - gated_V = gated_V.unflatten(2, (-1, CL)) - gated_K = gated_K.unflatten(2, (-1, CL)) + # We start from cached values, which matters in inference + + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] + + # Associative scan + + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. + + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) + + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + + next_V = next_V.flatten(2, 3) + next_K = next_K.flatten(2, 3) + + return next_V, next_K + + ################################################################# - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + next_V, next_K = recurrence(G, V, K) - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) - self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) + self.rec_V[:, :, t0:t1] = next_V + self.rec_K[:, :, t0:t1] = next_K ###################################################################### # compute the readout @@ -606,14 +676,14 @@ class Caterpillar(nn.Module): # the column in the caterpillar windowed_V = moving_window( - self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) windowed_K = moving_window( - self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) - # We have an attention score for each of the CHxCL values + # We have an attention score for each of the RxL values ar = torch.einsum( "nhtd,nftld->nhtfl", @@ -655,6 +725,8 @@ class QKVAttention(nn.Module): nb_heads=1, causal=False, attention_dropout=0.0, + logger=print, + **kwargs, ): super().__init__() @@ -746,6 +818,8 @@ class MyGPT(nn.Module): dropout=0.0, len_max=1e5, attention_layer="kvrec", + logger=print, + **kwargs, ): super().__init__() @@ -782,6 +856,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "dumbrec": return DumbRec( @@ -791,6 +867,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "kvrec": return KVRec( @@ -800,6 +878,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -810,6 +890,8 @@ class MyGPT(nn.Module): caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, attention_dropout=dropout, + logger=logger, + **kwargs, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")