- next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
- next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
-
- next_V = next_V.flatten(2, 3)
- next_K = next_K.flatten(2, 3)
+ next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
+ next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)