Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 01:53:53 +0000 (02:53 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 01:53:53 +0000 (02:53 +0100)
pscan.py

index 6a9057e..1dfb442 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -67,7 +67,7 @@ class PScan(torch.autograd.Function):
         PScan.accrev(R)
         Q = ctx.Y0 / ctx.A
         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
-        return (Q * R).sum(-1), R / ctx.A_star, U
+        return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
 
 
 pscan = PScan.apply
@@ -75,21 +75,28 @@ pscan = PScan.apply
 ######################################################################
 
 if __name__ == "__main__":
+    N, T, D = 2, 5, 3
+
     # Iterative implementation
 
-    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
-    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
-    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+    A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
-    y = Y0[:, None]
+    y = Y0
+    s = 0
 
     for k in range(A.size(1)):
         y = A[:, k, None] * y + X[:, k]
-        print(f"{k} -> {y}")
+        s = s + y
+        # print(f"{k} -> {y}")
+
+    s = s.sum()
 
-    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+    # print(s)
+    print(torch.autograd.grad(s, A, retain_graph=True))
+    print(torch.autograd.grad(s, X, retain_graph=True))
+    print(torch.autograd.grad(s, Y0, retain_graph=True))
 
     print()
 
@@ -97,11 +104,12 @@ if __name__ == "__main__":
 
     Y = pscan(A, X, Y0)
 
-    for k in range(A.size(1)):
-        print(f"{k} -> {Y[:,k]}")
+    for k in range(A.size(1)):
+    # print(f"{k} -> {Y[:,k]}")
 
-    y = Y[:, -1]
+    s = Y.sum()
 
-    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+    # print(s)
+    print(torch.autograd.grad(s, A, retain_graph=True))
+    print(torch.autograd.grad(s, X, retain_graph=True))
+    print(torch.autograd.grad(s, Y0, retain_graph=True))