- dropout_start = (
- (
- torch.rand(G.size(), device=G.device)
- .flatten(2, 3)
- .sort(dim=2)
- .indices
- == 0
- )
- .unflatten(2, (CH, t1 - t0))
+ dropout_head = (
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)