- self.rec_V[:, :, t0:t1] = (
- mask * V[n, src_head, src_time, dv]
- + (1 - mask) * self.rec_V[:, :, t0:t1]
+ if self.training and self.gate_dropout_proba > 0.0:
+ # G is NxHxRxT where r is the caterpillar's row.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+
+ if self.gate_dropout_sync:
+ shape_kill = (N, 1, 1)
+ else:
+ shape_kill = (N, H, R)
+
+ # Pick a point in each of the NxHxR timeline and set this
+ # entry and the following to 1
+ kill = (
+ torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+ == 0
+ ).cumsum(dim=3)
+
+ # Keep these mask for only some of the NxHxR
+ kill = kill * (
+ torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba