Update.
[mygptrnn.git] / fridge
diff --git a/fridge b/fridge
index 5dd85dd..d28cc89 100644 (file)
--- a/fridge
+++ b/fridge
@@ -74,3 +74,46 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
 
 
 ######################################################################
+
+######################################################################
+
+2024 Jan 07 21:38:11 (from mygpt.py)
+
+            # insert_flash_back(self.rec_V,V,self.rec_K,K,t0,t1,CL,proba=self.proba_flashback / CL,)
+
+
+######################################################################
+
+2024 Jan 09 14:24:42 (from mygpt.py)
+
+            # This piece of code makes the assumption that there is
+            # nothing informative before t0, otherwise we'd have to
+            # implement a cache for V and K too. This should not be
+            # too much of a problem since this is used only during
+            # train, where full sequence are available
+
+            # n = torch.arange(N, device=X.device)[:, None, None, None]
+            # t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
+            # dv = torch.arange(DV, device=X.device)[None, None, None, :]
+            # dk = torch.arange(DK, device=X.device)[None, None, None, :]
+
+            # u = (
+                # torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
+            # ) * CL
+
+            # src_time = t - u - t0
+            # src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
+
+            # mask = (
+                # torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
+            # ).long()
+
+            # self.rec_V[:, :, t0:t1] = (
+                # mask * V[n, src_head, src_time, dv]
+                # + (1 - mask) * self.rec_V[:, :, t0:t1]
+            # )
+
+            # self.rec_K[:, :, t0:t1] = (
+                # mask * K[n, src_head, src_time, dk]
+                # + (1 - mask) * self.rec_K[:, :, t0:t1]
+            # )