- n = torch.arange(N, device=X.device)[:, None, None, None]
- t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
- dv = torch.arange(DV, device=X.device)[None, None, None, :]
- dk = torch.arange(DK, device=X.device)[None, None, None, :]
+ next_V = next_V.flatten(2, 3)
+ next_K = next_K.flatten(2, 3)