init_rec_V = self.rec_V[:, :, t0 - L : t0]
init_rec_K = self.rec_K[:, :, t0 - L : t0]
- # Associative scan
-
# Here there is a trick: Since the stack at position t is
# computed by updating that at position t-L, the parallel
# scan operates with a period of L. To do so we split the
warnings.warn("gate dropout", RuntimeWarning)
+ # kill = (
+ # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+ # ).float()
+
kill = (
- torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
- ).float()
+ torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+ ).cumsum(dim=3)
+ kill = kill * (
+ torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
+ )
mask = 1 - kill