Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 18 Jan 2024 23:19:12 +0000 (00:19 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 18 Jan 2024 23:19:12 +0000 (00:19 +0100)
mygpt.py

index 5451584..fb24b9a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -644,17 +644,18 @@ class Caterpillar(nn.Module):
 
             warnings.warn("gate dropout", RuntimeWarning)
 
-            # kill = (
-            # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            # ).float()
-
+            # Pick a point in each of the NxHxR timeline and set this
+            # entry and the following to 1
             kill = (
                 torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
             ).cumsum(dim=3)
+
+            # Keep these mask for only some of the NxHxR
             kill = kill * (
                 torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
             )
 
+            # The coefficient to keep are the complementary
             mask = 1 - kill
 
             masked_next_V, masked_next_K = recurrence(G * mask, V, K)
@@ -674,8 +675,8 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
-        # We build tensors NxHxTxFxL where N is the sample index, H
-        # the head, T the time, F the row in the caterpillar, and L
+        # We build tensors NxHxTxRxL where N is the sample index, H
+        # the head, T the time, R the row in the caterpillar, and L
         # the column in the caterpillar
 
         windowed_V = moving_window(
@@ -689,7 +690,7 @@ class Caterpillar(nn.Module):
         # We have an attention score for each of the RxL values
 
         ar = torch.einsum(
-            "nhtd,nftld->nhtfl",
+            "nhtd,nrtld->nhtrl",
             Q,
             windowed_K,
         ) / math.sqrt(DK)