#
# returns Y of same shape as X, with
#
- # Y[:,t] = A[:,0] * Y_init + X[:,0] if t == 0
- # = A[:,t] * Y[:,t-1] + X[:,t] otherwise
+ # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0
+ # = A[:, t] * Y[:, t-1] + X[:, t] otherwise
@staticmethod
def forward(ctx, A, X, Y_init):
print((gA - gA_ref).norm())
print((gX - gX_ref).norm())
print((gY_init - gY_init_ref).norm())
+
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())