Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 00:39:59 +0000 (01:39 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 00:39:59 +0000 (01:39 +0100)
pscan.py

index d009692..f344200 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -4,71 +4,96 @@ import torch
 
 ######################################################################
 
-
-def naive_rec(A, X, Y0):
-    Y = []
-    for t in range(X.size(1)):
-        if t == 0:
-            Y.append(A[:, t] * Y0 + X[:, t])
-        else:
-            Y.append(A[:, t] * Y[-1] + X[:, t])
-
-    return torch.cat([y[:, None, :] for y in Y], dim=1)
-
-
-######################################################################
-
-# A is NxTx1
-# X is NxTxD
-# Y0 is NxD
+# Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
+# and O(log(T)) if not core-bounded, so that
 #
-# Returns Y defined with
+# Y[:, 0] = Y0
+# Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
 #
-#           Y[:, 0] = A[:, 0] * Y0 + X[:,0]
-# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
-
-
-def pscan_rec(A, X, Y0):
-    if X.size(1) % 2 == 1:
-        if X.size(1) == 1:
-            return A[:, :1] * Y0[:, None] + X[:, :1]
-        else:
-            Y = pscan_rec(A[:, :-1], X[:, :-1], Y0)
-            return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1)
-
-    A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2))
-    X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2))
-
-    X_star = X2[:, :, 0].clone()
-    X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1]
-
-    A_star = A2[:, :, 0].clone()
-    A_star[:, 1:] *= A2[:, :-1, 1]
-
-    Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None]
-
-    Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2)
-
-    Y = Y.reshape(Y.size(0), -1, Y.size(-1))
+# can be computed as
+#
+# Y[:, t] = A[:, t] * Y0 + X[:, t]
+
+
+def expand(A, X):
+    if A.size(1) == 1:
+        return
+    T = 2 * (A.size(1) // 2)
+    Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+    Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
+    Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+    Aa[:, :, 1].mul_(Aa[:, :, 0])
+    expand(Aa[:, :, 1], Xa[:, :, 1])
+    Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+    Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+    if T < A.size(1):
+        X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+        A[:, -1].mul_(A[:, -2])
+
+
+# Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
+
+
+def accrev(X):
+    if X.size(1) == 1:
+        return
+    T = 2 * (X.size(1) // 2)
+    Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
+    Xa[:, :, 0].add_(Xa[:, :, 1])
+    accrev(Xa[:, :, 0])
+    Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
+    if T < X.size(1):
+        X[:, 0].add_(X[:, 1])
+
+
+class PScan(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, A, X, Y0):
+        ctx.A = A[:, :, None].clone()
+        ctx.Y0 = Y0[:, None, :].clone()
+        ctx.A_star = A[:, :, None].clone()
+        ctx.X_star = X.clone()
+        expand(ctx.A_star, ctx.X_star)
+        return ctx.A_star * ctx.Y0 + ctx.X_star
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        U = grad_output * ctx.A_star
+        R = U.clone()
+        accrev(R)
+        Q = ctx.Y0 / ctx.A
+        Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
+        return (Q * R).sum(-1), R / ctx.A_star, U
+
+
+pscan = PScan.apply
 
-    return Y
+######################################################################
 
+if __name__ == "__main__":
+    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
+    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
+    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
 
-######################################################################
+    y = Y0[:, None]
 
-N, T, D = 5, 29, 12
+    for k in range(A.size(1)):
+        y = A[:, k, None] * y + X[:, k]
+        print(f"{k} -> {y}")
 
-A = torch.rand(N, T, 1, dtype=torch.float64)
-X = torch.randint(10, (N, T, D), dtype=torch.float64)
-Y0 = torch.randint(10, (N, D), dtype=torch.float64)
+    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
+    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
+    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
 
-naive_Y = naive_rec(A, X, Y0)
+    Y = pscan(A, X, Y0)
 
-pscan_Y = pscan_rec(A, X, Y0)
+    print()
 
-print((naive_Y - pscan_Y).pow(2).mean())
+    for k in range(A.size(1)):
+        print(f"{k} -> {Y[:,k]}")
 
-pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0)
-pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1])
+    y = Y[:, -1]
 
-print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean())
+    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
+    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
+    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))