print((gA - gA_ref).norm())
print((gX - gX_ref).norm())
print((gY_init - gY_init_ref).norm())
+
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())