X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=b137cdbca1cd085ee8dec0185c514118e928b5d3;hb=9112db2ed7d8c262c4ef8298cf6637515675f967;hp=099847c95d9404d477b069d8cdf78a62304b3784;hpb=2434c00a82ebb0b23f45d891cc9f80324e3200bd;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 099847c..b137cdb 100755 --- a/mygpt.py +++ b/mygpt.py @@ -126,7 +126,6 @@ class AddPositionalEncoding(nn.Module): import pscan - # X is /.../xTxD A is /.../xT Y_init is /.../xD @@ -147,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2): return Y +def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2): + with torch.no_grad(): + s_A, s_X = 0, 0 + for t in range(X.size(dim) - 1, 0, -1): + delta = (grad_Y[t] - s_A) / A[t].grad + s_A += A[t].grad * delta + A[t].grad = delta + delta = (grad_Y[t] - s_X) / X[t].grad + s_X += X[t].grad * delta + X[t].grad = delta + + def pscan_shape(A, X, Y_init): s = X.size() A = A.reshape(-1, s[-2]) @@ -191,7 +202,7 @@ class DumbRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -322,7 +333,7 @@ class KVRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -476,7 +487,7 @@ class Caterpillar(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -491,16 +502,13 @@ class Caterpillar(nn.Module): self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.proba_gate_dropout = 0.0 + self.gate_dropout_proba = args.gate_dropout_proba + self.gate_dropout_sync = args.gate_dropout_sync + self.gate_dropout_replace = args.gate_dropout_replace - default_bg = kwargs.get("default_bg") - if default_bg is None: - default_bg = -math.log(caterpillar_height - 1) - else: - default_bg = float(default_bg) - - logger(f"default_bg {default_bg}") + ###################################################################### + default_bg = -math.log(caterpillar_height - 1) self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) @@ -520,14 +528,14 @@ class Caterpillar(nn.Module): dim_v, ) - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 + # def reset_inner_loss(self): + # self.acc_attention = 0 + # self.acc_nb = 0 - def get_inner_loss(self): - # warnings.warn("l2 regularization", RuntimeWarning) - # return (self.acc_attention / self.acc_nb).pow(2).sum() - return torch.tensor([0], device=self.w_Q.device) + # def get_inner_loss(self): + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_Q.device) def forward(self, bs): # Dimensions to make the source a bit clearer, that's needed @@ -554,8 +562,8 @@ class Caterpillar(nn.Module): self.rec_K = X.new_zeros(N, R, T, DK) # We start the recurrent sequences with optimizable # initial values. No idea if it helps. - self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :] - self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :] + self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :] + self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :] self.cache_Y = X.new_zeros(N, T, DM) @@ -581,92 +589,87 @@ class Caterpillar(nn.Module): G = G / G.sum(1, keepdim=True).clamp(min=1) ###################################################################### - # Roll the gating indexes - - # warnings.warn("rotating barrel", RuntimeWarning) - - # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - # r_barrel = (r_barrel + (t_barrel + t0) // L) % R - # G = G.gather(dim=2, index=r_barrel.expand_as(G)) - - ###################################################################### - # The "flashbacks" - if self.training and self.proba_gate_dropout > 0.0: - # This is a better implementation of "flashbacks". + def recurrence(G, V, K): + # We prepare the arguments for the parallel scan - # G is NxHxExT where e is the caterpillar's row. + A = 1 - G.sum(1) - warnings.warn("gate dropout", RuntimeWarning) - epsilon = 0.5 + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - dropout_head = ( - (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) - .expand_as(G) - .float() - ) + # We start from cached values, which matters in inference - dropout_tail = dropout_head.cumsum(dim=3) - dropout_head + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - dropout_active = ( - torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout - ).long() + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. - dropout_head *= dropout_active - dropout_tail *= dropout_active + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) - G = ( - G - + dropout_head * (1 - epsilon - G.detach()) - - dropout_tail * G.detach() - ) + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) - ###################################################################### - - # We prepare the arguments for the parallel scan + return next_V, next_K - A = 1 - G.sum(1) + ################################################################# - # warnings.warn("harmonic recurrence", RuntimeWarning) - # har = torch.arange(t0, t1, device = G.device).float() + 1 - # A = har / (har + 1) - # G = G / har + next_V, next_K = recurrence(G, V, K) - gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) - gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) + if self.training and self.gate_dropout_proba > 0.0: + # G is NxHxRxT where r is the caterpillar's row. - # We start from cached values, which matters in inference + warnings.warn("gate dropout", RuntimeWarning) - init_rec_V = self.rec_V[:, :, t0 - L : t0] - init_rec_K = self.rec_K[:, :, t0 - L : t0] + if self.gate_dropout_sync: + shape_kill = (N, 1, 1) + else: + shape_kill = (N, H, R) + + # Pick a point in each of the NxHxR timeline and set this + # entry and the following to 1 + kill = ( + torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices + == 0 + ).cumsum(dim=3) + + # Keep these mask for only some of the NxHxR + kill = kill * ( + torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba + ) - ################################################################# - # Associative scan + # The coefficient to keep are the complementary + mask = 1 - kill - # Here there is a trick: Since the stack at position t is - # computed by updating that at position t-L, the parallel - # scan operates with a period of L. To do so we split the - # sequence indexing in two axes, the second of size L, and - # run the parallel scan using the first as the sequence index. + masked_next_V, masked_next_K = recurrence(G * mask, V, K) - A = A.unflatten(2, (-1, L)) - gated_V = gated_V.unflatten(2, (-1, L)) - gated_K = gated_K.unflatten(2, (-1, L)) + if self.gate_dropout_replace: + next_V = next_V.detach() + next_K = next_K.detach() - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + next_V = next_V + (masked_next_V - masked_next_V.detach()) / ( + 1 - self.gate_dropout_proba + ) + next_K = next_K + (masked_next_K - masked_next_K.detach()) / ( + 1 - self.gate_dropout_proba + ) - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) - self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) + self.rec_V[:, :, t0:t1] = next_V + self.rec_K[:, :, t0:t1] = next_K ###################################################################### # compute the readout Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - # We build tensors NxHxTxFxL where N is the sample index, H - # the head, T the time, F the row in the caterpillar, and L + # We build tensors NxHxTxRxL where N is the sample index, H + # the head, T the time, R the row in the caterpillar, and L # the column in the caterpillar windowed_V = moving_window( @@ -680,7 +683,7 @@ class Caterpillar(nn.Module): # We have an attention score for each of the RxL values ar = torch.einsum( - "nhtd,nftld->nhtfl", + "nhtd,nrtld->nhtrl", Q, windowed_K, ) / math.sqrt(DK) @@ -720,7 +723,7 @@ class QKVAttention(nn.Module): causal=False, attention_dropout=0.0, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -813,7 +816,7 @@ class MyGPT(nn.Module): len_max=1e5, attention_layer="kvrec", logger=print, - **kwargs, + args=None, ): super().__init__() @@ -851,7 +854,7 @@ class MyGPT(nn.Module): causal=causal, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "dumbrec": return DumbRec( @@ -862,7 +865,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "kvrec": return KVRec( @@ -873,7 +876,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -885,7 +888,7 @@ class MyGPT(nn.Module): caterpillar_height=self.caterpillar_height, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")