- self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
- self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
+ #################################################################
+
+ next_V, next_K = recurrence(G, V, K)
+
+ if self.training and self.gate_dropout_proba > 0.0:
+ # G is NxHxRxT where r is the caterpillar's row.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+
+ # Pick a point in each of the NxHxR timelines and set this
+ # entry and all the following ones to 1
+ kill = (
+ torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+ ).cumsum(dim=3)
+
+ # Keep these masks for only some of the NxHxR timelines
+ kill = kill * (
+ torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
+ )
+
+ # The coefficients to keep are the complement of the killed ones
+ mask = 1 - kill
+
+ masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+ next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+ 1 - self.gate_dropout_proba
+ )
+ next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+ 1 - self.gate_dropout_proba
+ )
+
+ self.rec_V[:, :, t0:t1] = next_V
+ self.rec_K[:, :, t0:t1] = next_K