# mask * K[n, src_head, src_time, dk]
# + (1 - mask) * self.rec_K[:, :, t0:t1]
# )
+
+######################################################################
+
+2024 Jan 10 08:10:39 (from mygpt.py)
+
+ # That was a bad idea
+ # G = F.dropout(G, self.attention_dropout, self.training)
+
+
+######################################################################
+
+2024 Jan 10 08:46:13 (from mygpt.py)
+
+ #################################################################
+ # Flashbacks. This version sucks, about to replace it
+ if self.training and self.proba_flashback > 0.0:
+ warnings.warn("flash back", RuntimeWarning)
+ # This piece of code makes the assumption that there is
+ # nothing informative before t0, otherwise we'd have to
+ # implement a cache for V and K too. This should not be
+ # too much of a problem since this is used only during
+ # train, where full sequence are available
+
+ n = torch.arange(N, device=X.device)[:, None, None, None]
+ t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
+ dv = torch.arange(DV, device=X.device)[None, None, None, :]
+ dk = torch.arange(DK, device=X.device)[None, None, None, :]
+
+ u = (
+ torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
+ ) * CL
+
+ src_time = t - u - t0
+ src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
+
+ mask = (
+ torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
+ ).long()
+
+ self.rec_V[:, :, t0:t1] = (
+ mask * V[n, src_head, src_time, dv]
+ + (1 - mask) * self.rec_V[:, :, t0:t1]
+ )
+
+ self.rec_K[:, :, t0:t1] = (
+ mask * K[n, src_head, src_time, dk]
+ + (1 - mask) * self.rec_K[:, :, t0:t1]
+ )
+
+
+######################################################################
+
+2024 Jan 13 13:38:31 (from mygpt.py)
+
+ g= F.sigmoid(self.b_G)
+ a=1-g
+
+ print(f"\n\nSANITY {a**T}\n")
+ exit(0)
+
+
+######################################################################
+
+2024 Jan 14 13:39:37 (from mygpt.py)
+
+ epsilon = 0.5
+
+ dropout_head = (
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)
+ .float()
+ )
+
+ dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
+
+ dropout_active = (
+ torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+ ).long()
+
+ dropout_head *= dropout_active
+ dropout_tail *= dropout_active
+
+ G = (
+ G
+ + dropout_head * (1 - epsilon - G.detach())
+ - dropout_tail * G.detach()
+ )