X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=1dfb44229b88f96fae0a5c8ff67f7df4ee968a5a;hb=e2886553f1cd8fce2e31b352935330cb7cb5b325;hp=7b6cfc0ce21dc4938cd29437314daa9b5562a204;hpb=5e18a91ef9d1c1362a21b325425990e34ad18463;p=pytorch.git diff --git a/pscan.py b/pscan.py index 7b6cfc0..1dfb442 100755 --- a/pscan.py +++ b/pscan.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import torch ###################################################################### @@ -62,7 +67,7 @@ class PScan(torch.autograd.Function): PScan.accrev(R) Q = ctx.Y0 / ctx.A Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) - return (Q * R).sum(-1), R / ctx.A_star, U + return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1) pscan = PScan.apply @@ -70,21 +75,28 @@ pscan = PScan.apply ###################################################################### if __name__ == "__main__": + N, T, D = 2, 5, 3 + # Iterative implementation - A = torch.randn(1, 5, dtype=torch.float64).requires_grad_() - X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_() - Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_() + A = torch.randn(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() - y = Y0[:, None] + y = Y0 + s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] - print(f"{k} -> {y}") + s = s + y + # print(f"{k} -> {y}") - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + s = s.sum() + + # print(s) + print(torch.autograd.grad(s, A, retain_graph=True)) + print(torch.autograd.grad(s, X, retain_graph=True)) + print(torch.autograd.grad(s, Y0, retain_graph=True)) print() @@ -92,11 +104,12 @@ if __name__ == "__main__": Y = pscan(A, X, Y0) - for k in range(A.size(1)): - print(f"{k} -> {Y[:,k]}") + # for k in range(A.size(1)): + # print(f"{k} -> {Y[:,k]}") - y = Y[:, -1] + s = Y.sum() - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + # print(s) + print(torch.autograd.grad(s, A, retain_graph=True)) + print(torch.autograd.grad(s, X, retain_graph=True)) + print(torch.autograd.grad(s, Y0, retain_graph=True))