Update.
[mygptrnn.git] / fridge
diff --git a/fridge b/fridge
index 194c4e6..f87c1df 100644 (file)
--- a/fridge
+++ b/fridge
@@ -177,3 +177,30 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
         print(f"\n\nSANITY {a**T}\n")
         exit(0)
 
+
+######################################################################
+
+2024 Jan 14 13:39:37 (from mygpt.py)
+
+            epsilon = 0.5
+
+            dropout_head = (
+                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+                .expand_as(G)
+                .float()
+            )
+
+            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
+
+            dropout_active = (
+                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+            ).long()
+
+            dropout_head *= dropout_active
+            dropout_tail *= dropout_active
+
+            G = (
+                G
+                + dropout_head * (1 - epsilon - G.detach())
+                - dropout_tail * G.detach()
+            )