X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=071f284f9af8ee9a0a076bf0d2454e3f66adf52f;hb=59513fa775776af525477f01925f563ddffecdb2;hp=7b6cfc0ce21dc4938cd29437314daa9b5562a204;hpb=5e18a91ef9d1c1362a21b325425990e34ad18463;p=pytorch.git diff --git a/pscan.py b/pscan.py index 7b6cfc0..071f284 100755 --- a/pscan.py +++ b/pscan.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import torch ###################################################################### @@ -62,7 +67,7 @@ class PScan(torch.autograd.Function): PScan.accrev(R) Q = ctx.Y0 / ctx.A Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) - return (Q * R).sum(-1), R / ctx.A_star, U + return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1) pscan = PScan.apply @@ -70,33 +75,33 @@ pscan = PScan.apply ###################################################################### if __name__ == "__main__": - # Iterative implementation + N, T, D = 2, 5, 3 - A = torch.randn(1, 5, dtype=torch.float64).requires_grad_() - X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_() - Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_() + A = torch.randn(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() - y = Y0[:, None] + # Iterative implementation + + y = Y0 + s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] - print(f"{k} -> {y}") + s = s + y - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + s = s.sum() - print() + gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) # parallel scan Y = pscan(A, X, Y0) - for k in range(A.size(1)): - print(f"{k} -> {Y[:,k]}") + s = Y.sum() - y = Y[:, -1] + gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY0 - gY0_ref).norm())