Update.
[pytorch.git] / pscan.py
index 3526c31..891ec55 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -55,8 +55,8 @@ class PScan(torch.autograd.Function):
     #
     # returns Y of same shape as X, with
     #
-    # Y[:,t] = A[:,0] * Y_init   + X[:,0] if t == 0
-    #        = A[:,t] * Y[:,t-1] + X[:,t] otherwise
+    # Y[:, t] = A[:, 0] * Y_init   + X[:, 0] if t == 0
+    #        = A[:, t] * Y[:, t-1] + X[:, t] otherwise
 
     @staticmethod
     def forward(ctx, A, X, Y_init):
@@ -114,3 +114,8 @@ if __name__ == "__main__":
     print((gA - gA_ref).norm())
     print((gX - gX_ref).norm())
     print((gY_init - gY_init_ref).norm())
+
+    Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+    Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+    print((Y - torch.cat([Y1, Y2], dim=1)).norm())