+ mask_V = (
+ torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
+ ).long()
+ self.rec_V[:, :, t0:t1] = (
+ mask_V * V[n, src_head, src_time, dv]
+ + (1 - mask_V) * self.rec_V[:, :, t0:t1]
+ )
+
+ mask_K = (
+ torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
+ ).long()
+ self.rec_K[:, :, t0:t1] = (
+ mask_K * K[n, src_head, src_time, dk]
+ + (1 - mask_K) * self.rec_K[:, :, t0:t1]
+ )