Update.
[mygptrnn.git] / mygpt.py
index 185df38..099847c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -190,6 +190,8 @@ class DumbRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -319,6 +321,8 @@ class KVRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -471,6 +475,8 @@ class Caterpillar(nn.Module):
         caterpillar_height,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -487,12 +493,16 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
+        default_bg = kwargs.get("default_bg")
+        if default_bg is None:
+            default_bg = -math.log(caterpillar_height - 1)
+        else:
+            default_bg = float(default_bg)
+
+        logger(f"default_bg {default_bg}")
+
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(
-            torch.full(
-                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
-            )
-        )
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
         self.w_V = randw(nb_heads, dim_v, dim_model)
@@ -530,22 +540,22 @@ class Caterpillar(nn.Module):
         DV = self.w_V.size(1)
         DK = self.w_K.size(1)
         DM = self.w_O.size(1)
-        CH = self.caterpillar_height
-        CL = self.caterpillar_length
+        R = self.caterpillar_height
+        L = self.caterpillar_length
 
         assert (
-            t0 >= CL and (t1 - t0) % CL == 0
+            t0 >= L and (t1 - t0) % L == 0
         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
 
         # We cache values to deal efficiently with auto-regression
 
         if bs.init_cache:
-            self.rec_V = X.new_zeros(N, CH, T, DV)
-            self.rec_K = X.new_zeros(N, CH, T, DK)
+            self.rec_V = X.new_zeros(N, R, T, DV)
+            self.rec_K = X.new_zeros(N, R, T, DK)
             # We start the recurrent sequences with optimizable
             # initial values. No idea if it helps.
-            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
-            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
+            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
+            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]
 
             self.cache_Y = X.new_zeros(N, T, DM)
 
@@ -556,8 +566,8 @@ class Caterpillar(nn.Module):
         # Compute the recurrent state
 
         # This is the Gating sequence that modulates the storing of
-        # the new key and value in the CH pairs of the current
-        # stack. There are CH independent gating values, which means
+        # the new key and value in the R pairs of the current
+        # stack. There are R independent gating values, which means
         # that the current K/V may be stored in multiple pairs of the
         # recurrent state, or not at all.
 
@@ -565,6 +575,21 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
+
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
+
+        ######################################################################
+        # Roll the gating indexes
+
+        # warnings.warn("rotating barrel", RuntimeWarning)
+
+        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+
         ######################################################################
         # The "flashbacks"
 
@@ -593,7 +618,7 @@ class Caterpillar(nn.Module):
 
             G = (
                 G
-                + dropout_head * (1 - epsilon - G.detach())
+                + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )
 
@@ -601,32 +626,33 @@ class Caterpillar(nn.Module):
 
         # We prepare the arguments for the parallel scan
 
-        # Clip the gating to avoid values greater than 1 when several
-        # heads hit the same row
+        A = 1 - G.sum(1)
 
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+        # warnings.warn("harmonic recurrence", RuntimeWarning)
+        # har = torch.arange(t0, t1, device = G.device).float() + 1
+        # A = har / (har + 1)
+        # G = G / har
 
-        A = 1 - G.sum(1)
         gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
         gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
         # We start from cached values, which matters in inference
 
-        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
-        init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+        init_rec_V = self.rec_V[:, :, t0 - L : t0]
+        init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
         #################################################################
         # Associative scan
 
         # Here there is a trick: Since the stack at position t is
-        # computed by updating that at position t-CL, the parallel
-        # scan operates with a period of CL. To do so we split the
-        # sequence indexing in two axes, the second of size CL, and
+        # computed by updating that at position t-L, the parallel
+        # scan operates with a period of L. To do so we split the
+        # sequence indexing in two axes, the second of size L, and
         # run the parallel scan using the first as the sequence index.
 
-        A = A.unflatten(2, (-1, CL))
-        gated_V = gated_V.unflatten(2, (-1, CL))
-        gated_K = gated_K.unflatten(2, (-1, CL))
+        A = A.unflatten(2, (-1, L))
+        gated_V = gated_V.unflatten(2, (-1, L))
+        gated_K = gated_K.unflatten(2, (-1, L))
 
         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
@@ -644,14 +670,14 @@ class Caterpillar(nn.Module):
         # the column in the caterpillar
 
         windowed_V = moving_window(
-            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
         )
 
         windowed_K = moving_window(
-            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
         )
 
-        # We have an attention score for each of the CHxCL values
+        # We have an attention score for each of the RxL values
 
         ar = torch.einsum(
             "nhtd,nftld->nhtfl",
@@ -693,6 +719,8 @@ class QKVAttention(nn.Module):
         nb_heads=1,
         causal=False,
         attention_dropout=0.0,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -784,6 +812,8 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
         attention_layer="kvrec",
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -820,6 +850,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     causal=causal,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             elif attention_layer == "dumbrec":
                 return DumbRec(
@@ -829,6 +861,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             elif attention_layer == "kvrec":
                 return KVRec(
@@ -838,6 +872,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             elif attention_layer == "caterpillar":
                 return Caterpillar(
@@ -848,6 +884,8 @@ class MyGPT(nn.Module):
                     caterpillar_length=self.caterpillar_length,
                     caterpillar_height=self.caterpillar_height,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             else:
                 raise ValueError(f"Unknown attention type {attention_layer}.")