- n = torch.arange(N, device=X.device)[:, None, None, None]
- t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
- dv = torch.arange(DV, device=X.device)[None, None, None, :]
- dk = torch.arange(DK, device=X.device)[None, None, None, :]
+ next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+ next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)