Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jan 2024 16:24:46 +0000 (17:24 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jan 2024 16:24:46 +0000 (17:24 +0100)
mygpt.py

index 676b921..7c8e9f4 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -476,8 +476,10 @@ class Caterpillar(nn.Module):
 
         warnings.warn("Caterpillar", RuntimeWarning)
 
-        def randw(*d):
-            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+        def randw(*d, amplitude=None):
+            if amplitude is None:
+                amplitude = 1 / math.sqrt(d[-1])
+            return nn.Parameter(amplitude * torch.randn(*d))
 
         self.caterpillar_length = caterpillar_length
         self.caterpillar_height = caterpillar_height
@@ -485,7 +487,7 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
         self.b_G = nn.Parameter(
             torch.full(
                 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
@@ -497,8 +499,12 @@ class Caterpillar(nn.Module):
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
-        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
-        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+        self.init_K_rec = randw(
+            caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
+        )
+        self.init_V_rec = randw(
+            caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
+        )
 
     def reset_inner_loss(self):
         self.acc_attention = 0
@@ -552,34 +558,65 @@ class Caterpillar(nn.Module):
         # recurrent state, or not at all.
 
         G = (
-            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        ######################################################################
+        # The "flashbacks"
+
+        if self.training and self.proba_gate_dropout > 0.0:
+            # This is a better implementation of "flashbacks".
+
+            # G is NxHxExT where e is the caterpillar's row.
+
+            warnings.warn("gate dropout", RuntimeWarning)
+            epsilon = 0.5
+
+            dropout_start = (
+                (
+                    torch.rand(G.size(), device=G.device)
+                    .flatten(2, 3)
+                    .sort(dim=2)
+                    .indices
+                    == 0
+                )
+                .unflatten(2, (CH, t1 - t0))
+                .float()
+            )
+
+            dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+
+            dropout_active = (
+                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+            ).long()
+
+            dropout_start *= dropout_active
+            dropout_tail *= dropout_active
+
+            G = (
+                G
+                + dropout_start * (1 - epsilon - G.detach())
+                - dropout_tail * G.detach()
+            )
+
+        ######################################################################
+
+        # We prepare the arguments for the parallel scan
+
         # Clip the gating to avoid values greater than 1 when several
         # heads hit the same row
 
         G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        # We prepare the arguments for the parallel scan
-
         A = 1 - G.sum(1)
-        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
-        gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
         # We start from cached values, which matters in inference
 
         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
 
-        ######################################################################
-
-        if self.training and self.proba_gate_dropout > 0.0:
-            # This is a better implementation of "flashbacks".  A is
-            # NxExT where e is the caterpillar's row.
-
-            warnings.warn("gate dropout", RuntimeWarning)
-            epsilon = 0.5
-
         #################################################################
         # Associative scan