X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=c8330128e115efc1c1f3773ada90e6b929559ce7;hb=a3c32b845b6903fd290f2b09d5c53203ff112b79;hp=67c5cfd96ff11a5fd04b88eb6b26a72c90e97ddb;hpb=6d23462ce76c9020dcd7c4bc8a0e7a0fae9b7971;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 67c5cfd..c833012 100755 --- a/mygpt.py +++ b/mygpt.py @@ -502,13 +502,9 @@ class Caterpillar(nn.Module): self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.gate_dropout_proba = args.gate_dropout_proba - self.gate_dropout_sync = args.gate_dropout_sync - self.gate_dropout_replace = args.gate_dropout_replace - ###################################################################### - self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3) + self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0)) self.w_K = randw(nb_heads, dim_qk, dim_model) @@ -569,8 +565,6 @@ class Caterpillar(nn.Module): V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) - # V, K = blanket(V), blanket(K) - ###################################################################### # Compute the recurrent state @@ -589,81 +583,30 @@ class Caterpillar(nn.Module): G = G / G.sum(1, keepdim=True).clamp(min=1) - # G_star = (1 - G).log().sum(1, keepdim=True).exp() - ###################################################################### - def recurrence(G, V, K): - # We prepare the arguments for the parallel scan - - A = 1 - G.sum(dim=1) - - gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) - gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - - # We start from cached values, which matters in inference - - init_rec_V = self.rec_V[:, :, t0 - L : t0] - init_rec_K = self.rec_K[:, :, t0 - L : t0] - - # Here there is a trick: Since the stack at position t is - # computed by updating that at position t-L, the parallel - # scan operates with a period of L. To do so we split the - # sequence indexing in two axes, the second of size L, and - # run the parallel scan using the first as the sequence index. - - A = A.unflatten(2, (-1, L)) - gated_V = gated_V.unflatten(2, (-1, L)) - gated_K = gated_K.unflatten(2, (-1, L)) - - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) + A = 1 - G.sum(dim=1) - return next_V, next_K + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - ################################################################# + # We start from cached values, which matters in inference - next_V, next_K = recurrence(G, V, K) + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - if self.training and self.gate_dropout_proba > 0.0: - # G is NxHxRxT where r is the caterpillar's row. + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. - warnings.warn("gate dropout", RuntimeWarning) + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) - if self.gate_dropout_sync: - shape_kill = (N, 1, 1) - else: - shape_kill = (N, H, R) - - # Pick a point in each of the NxHxR timeline and set this - # entry and the following to 1 - kill = ( - torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices - == 0 - ).cumsum(dim=3) - - # Keep these mask for only some of the NxHxR - kill = kill * ( - torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba - ) - - # The coefficient to keep are the complementary - mask = 1 - kill - - masked_next_V, masked_next_K = recurrence(G * mask, V, K) - - if self.gate_dropout_replace: - next_V = next_V.detach() - next_K = next_K.detach() - - warnings.warn("the rescaling is probably a bad idea", RuntimeWarning) - - next_V = next_V + (masked_next_V - masked_next_V.detach()) / ( - 1 - self.gate_dropout_proba - ) - next_K = next_K + (masked_next_K - masked_next_K.detach()) / ( - 1 - self.gate_dropout_proba - ) + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) self.rec_V[:, :, t0:t1] = next_V self.rec_K[:, :, t0:t1] = next_K @@ -710,10 +653,6 @@ class Caterpillar(nn.Module): windowed_V, ).flatten(2) - # Compute the final output - - # Y = blanket(Y) - self.cache_Y[:, t0:t1] = Y @ self.w_O return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) @@ -730,6 +669,7 @@ class QKVAttention(nn.Module): dim_v, nb_heads=1, causal=False, + horizon=None, attention_dropout=0.0, logger=print, args=None, @@ -740,6 +680,7 @@ class QKVAttention(nn.Module): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) self.causal = causal + self.horizon = horizon self.attention_dropout = attention_dropout self.record_attention = False @@ -783,6 +724,17 @@ class QKVAttention(nn.Module): torch.arange(x_q.size(1), device=q.device)[None, None, :, None] < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] ) + + if self.horizon is not None: + self.cache_attzero = torch.logical_or( + self.cache_attzero, + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + >= torch.arange(x_q.size(1), device=q.device)[ + None, None, None, : + ] + + self.horizon, + ) + a = a.masked_fill( self.cache_attzero[ :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb @@ -834,9 +786,10 @@ class MyGPT(nn.Module): "dumbrec", "kvrec", "caterpillar", + "attcat", }, f"Unknown attention operator {attention_layer}." - if attention_layer == "caterpillar": + if attention_layer == "caterpillar" or attention_layer == "attcat": assert nb_lines % caterpillar_height == 0 self.caterpillar_length = nb_lines // caterpillar_height self.caterpillar_height = caterpillar_height @@ -855,59 +808,99 @@ class MyGPT(nn.Module): def attlayer(): if attention_layer == "mha": - return QKVAttention( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - causal=causal, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "dumbrec": - return DumbRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + DumbRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "kvrec": - return KVRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + KVRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "caterpillar": - return Caterpillar( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - caterpillar_length=self.caterpillar_length, - caterpillar_height=self.caterpillar_height, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ) + elif attention_layer == "attcat": + return nn.Sequential( + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + horizon=self.caterpillar_length, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), ) else: raise ValueError(f"Unknown attention type {attention_layer}.") for b in range(nb_blocks): trunk_blocks += [ - WithResidual( - CacheWrapper(nn.LayerNorm((dim_model,))), - attlayer(), - ), + attlayer(), WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)),