Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 03:52:50 +0000 (04:52 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 03:52:50 +0000 (04:52 +0100)
pscan.py

index 071f284..3526c31 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -14,12 +14,12 @@ class PScan(torch.autograd.Function):
     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
     # and O(log(T)) if not core-bounded, so that
     #
-    # Y[:, 0] = Y0
+    # Y[:, 0] = Y_init
     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
     #
     # can be computed as
     #
-    # Y[:, t] = A[:, t] * Y0 + X[:, t]
+    # Y[:, t] = A[:, t] * Y_init + X[:, t]
 
     @staticmethod
     def expand(A, X):
@@ -51,21 +51,28 @@ class PScan(torch.autograd.Function):
         if T < X.size(1):
             X[:, 0].add_(X[:, 1])
 
+    # A is NxT, X is NxTxD, Y_init is NxD
+    #
+    # returns Y of same shape as X, with
+    #
+    # Y[:,t] = A[:,0] * Y_init   + X[:,0] if t == 0
+    #        = A[:,t] * Y[:,t-1] + X[:,t] otherwise
+
     @staticmethod
-    def forward(ctx, A, X, Y0):
+    def forward(ctx, A, X, Y_init):
         ctx.A = A[:, :, None].clone()
-        ctx.Y0 = Y0[:, None, :].clone()
+        ctx.Y_init = Y_init[:, None, :].clone()
         ctx.A_star = A[:, :, None].clone()
         ctx.X_star = X.clone()
         PScan.expand(ctx.A_star, ctx.X_star)
-        return ctx.A_star * ctx.Y0 + ctx.X_star
+        return ctx.A_star * ctx.Y_init + ctx.X_star
 
     @staticmethod
     def backward(ctx, grad_output):
         U = grad_output * ctx.A_star
         R = U.clone()
         PScan.accrev(R)
-        Q = ctx.Y0 / ctx.A
+        Q = ctx.Y_init / ctx.A
         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
         return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
 
@@ -79,11 +86,11 @@ if __name__ == "__main__":
 
     A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+    Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
     # Iterative implementation
 
-    y = Y0
+    y = Y_init
     s = 0
 
     for k in range(A.size(1)):
@@ -92,16 +99,18 @@ if __name__ == "__main__":
 
     s = s.sum()
 
-    gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+    gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+        s, (A, X, Y_init), retain_graph=True
+    )
 
     # parallel scan
 
-    Y = pscan(A, X, Y0)
+    Y = pscan(A, X, Y_init)
 
     s = Y.sum()
 
-    gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+    gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
 
     print((gA - gA_ref).norm())
     print((gX - gX_ref).norm())
-    print((gY0 - gY0_ref).norm())
+    print((gY_init - gY_init_ref).norm())