Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 18 Jan 2024 12:06:27 +0000 (13:06 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 18 Jan 2024 12:06:27 +0000 (13:06 +0100)
mygpt.py

index 492a9bb..5451584 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -617,8 +617,6 @@ class Caterpillar(nn.Module):
             init_rec_V = self.rec_V[:, :, t0 - L : t0]
             init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-            # Associative scan
-
             # Here there is a trick: Since the stack at position t is
             # computed by updating that at position t-L, the parallel
             # scan operates with a period of L. To do so we split the
@@ -646,9 +644,16 @@ class Caterpillar(nn.Module):
 
             warnings.warn("gate dropout", RuntimeWarning)
 
+            # kill = (
+            # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            # ).float()
+
             kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
+                torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+            ).cumsum(dim=3)
+            kill = kill * (
+                torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
+            )
 
             mask = 1 - kill