warnings.warn("gate dropout", RuntimeWarning)
+ if self.gate_dropout_sync:
+ shape_kill = (N, 1, 1)
+ else:
+ shape_kill = (N, H, R)
+
# Pick a point in each of the NxHxR timeline and set this
# entry and the following to 1
kill = (
- torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+ torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+ == 0
).cumsum(dim=3)
# Keep these mask for only some of the NxHxR
kill = kill * (
- torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
+ torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
)
# The coefficient to keep are the complementary