- # mk = (
- # torch.rand(self.rec_V[:, :, t0:t1].size()) <= self.proba_flashback
- # ).long()
- # self.rec_V[:, :, t0:t1] = V[n, src_head, src_time, dv]
- # self.rec_K[:, :, t0:t1] = K[n, src_head, src_time, dk]
+ # Random blend of the value cache: each element of the
+ # (N, CH, t1 - t0, DV) window is independently overwritten with the
+ # indexed V entry with probability self.proba_flashback; otherwise the
+ # current rec_V value is kept.
+ # The mask is allocated on V's device: torch.rand otherwise uses the
+ # default device, and mixing a CPU mask with CUDA tensors raises a
+ # device-mismatch RuntimeError.
+ # NOTE(review): assumes V[n, src_head, src_time, dv] broadcasts against
+ # (N, CH, t1 - t0, DV) and that rec_V lives on the same device as V --
+ # confirm against the code that builds the indices and the caches.
+ mask_V = (
+     torch.rand(N, CH, t1 - t0, DV, device=V.device) <= self.proba_flashback
+ ).long()
+ self.rec_V[:, :, t0:t1] = (
+     mask_V * V[n, src_head, src_time, dv]
+     + (1 - mask_V) * self.rec_V[:, :, t0:t1]
+ )
+
+ # Same blend for the key cache, using an independently drawn mask whose
+ # last dimension is DK and which is allocated on K's device.
+ mask_K = (
+     torch.rand(N, CH, t1 - t0, DK, device=K.device) <= self.proba_flashback
+ ).long()
+ self.rec_K[:, :, t0:t1] = (
+     mask_K * K[n, src_head, src_time, dk]
+     + (1 - mask_K) * self.rec_K[:, :, t0:t1]
+ )