From e2886553f1cd8fce2e31b352935330cb7cb5b325 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 18 Dec 2023 02:53:53 +0100
Subject: [PATCH] Update.

---
 pscan.py | 38 +++++++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/pscan.py b/pscan.py
index 6a9057e..1dfb442 100755
--- a/pscan.py
+++ b/pscan.py
@@ -67,7 +67,7 @@ class PScan(torch.autograd.Function):
         PScan.accrev(R)
         Q = ctx.Y0 / ctx.A
         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
-        return (Q * R).sum(-1), R / ctx.A_star, U
+        return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
 
 
 pscan = PScan.apply
@@ -75,21 +75,28 @@ pscan = PScan.apply
 ######################################################################
 
 if __name__ == "__main__":
+    N, T, D = 2, 5, 3
+
     # Iterative implementation
 
-    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
-    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
-    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+    A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
-    y = Y0[:, None]
+    y = Y0
+    s = 0
 
     for k in range(A.size(1)):
         y = A[:, k, None] * y + X[:, k]
-        print(f"{k} -> {y}")
+        s = s + y
+        # print(f"{k} -> {y}")
+
+    s = s.sum()
 
-    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+    # print(s)
+    print(torch.autograd.grad(s, A, retain_graph=True))
+    print(torch.autograd.grad(s, X, retain_graph=True))
+    print(torch.autograd.grad(s, Y0, retain_graph=True))
 
     print()
 
@@ -97,11 +104,12 @@ if __name__ == "__main__":
 
     Y = pscan(A, X, Y0)
 
-    for k in range(A.size(1)):
-        print(f"{k} -> {Y[:,k]}")
+    # for k in range(A.size(1)):
+    #     print(f"{k} -> {Y[:,k]}")
 
-    y = Y[:, -1]
+    s = Y.sum()
 
-    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
-    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+    # print(s)
+    print(torch.autograd.grad(s, A, retain_graph=True))
+    print(torch.autograd.grad(s, X, retain_graph=True))
+    print(torch.autograd.grad(s, Y0, retain_graph=True))
-- 
2.39.5
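
Note (not part of the patch): the updated `__main__` block checks the parallel scan against an iterative reference for the recurrence y_k = A[:, k] * y_{k-1} + X[:, k] with y_{-1} = Y0, now summing all outputs into a scalar s so every timestep contributes to the compared gradients. The following is a minimal standalone sketch of that iterative reference, reusing the names and shapes from the diff and only standard torch.autograd calls; variable names outside the diff (grad_A, grad_X, grad_Y0) are illustrative.

```python
import torch

# Batch size, sequence length, feature dimension (values from the patch).
N, T, D = 2, 5, 3

A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()

# Iterative scan: y_k = a_k * y_{k-1} + x_k, accumulated into s.
y, s = Y0, 0
for k in range(T):
    y = A[:, k, None] * y + X[:, k]  # broadcast a_k over the feature dimension
    s = s + y

s = s.sum()

# Gradients of the scalar s w.r.t. all inputs; the parallel pscan(A, X, Y0)
# is expected to reproduce these when Y.sum() is differentiated instead.
grad_A, = torch.autograd.grad(s, A, retain_graph=True)
grad_X, = torch.autograd.grad(s, X, retain_graph=True)
grad_Y0, = torch.autograd.grad(s, Y0, retain_graph=True)
print(grad_A.shape, grad_X.shape, grad_Y0.shape)
```

Since Y0 now feeds every timestep of the summed objective rather than only the last output, its gradient has an extra per-step structure, which is consistent with the backward() change returning U.sum(dim=1) instead of U.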