X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=pscan.py;h=891ec5578e118f48d8f5e933d3f4603e1703e9b1;hp=7b6cfc0ce21dc4938cd29437314daa9b5562a204;hb=HEAD;hpb=5e18a91ef9d1c1362a21b325425990e34ad18463 diff --git a/pscan.py b/pscan.py index 7b6cfc0..891ec55 100755 --- a/pscan.py +++ b/pscan.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import torch ###################################################################### @@ -9,12 +14,12 @@ class PScan(torch.autograd.Function): # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), # and O(log(T)) if not core-bounded, so that # - # Y[:, 0] = Y0 + # Y[:, 0] = Y_init # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] # # can be computed as # - # Y[:, t] = A[:, t] * Y0 + X[:, t] + # Y[:, t] = A[:, t] * Y_init + X[:, t] @staticmethod def expand(A, X): @@ -46,23 +51,30 @@ class PScan(torch.autograd.Function): if T < X.size(1): X[:, 0].add_(X[:, 1]) + # A is NxT, X is NxTxD, Y_init is NxD + # + # returns Y of same shape as X, with + # + # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0 + # = A[:, t] * Y[:, t-1] + X[:, t] otherwise + @staticmethod - def forward(ctx, A, X, Y0): + def forward(ctx, A, X, Y_init): ctx.A = A[:, :, None].clone() - ctx.Y0 = Y0[:, None, :].clone() + ctx.Y_init = Y_init[:, None, :].clone() ctx.A_star = A[:, :, None].clone() ctx.X_star = X.clone() PScan.expand(ctx.A_star, ctx.X_star) - return ctx.A_star * ctx.Y0 + ctx.X_star + return ctx.A_star * ctx.Y_init + ctx.X_star @staticmethod def backward(ctx, grad_output): U = grad_output * ctx.A_star R = U.clone() PScan.accrev(R) - Q = ctx.Y0 / ctx.A + Q = ctx.Y_init / ctx.A Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) - return (Q * R).sum(-1), R / ctx.A_star, U + return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1) pscan = PScan.apply @@ -70,33 +82,40 @@ pscan = PScan.apply ###################################################################### if __name__ == "__main__": - # Iterative implementation + N, T, D = 2, 5, 3 - A = torch.randn(1, 5, dtype=torch.float64).requires_grad_() - X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_() - Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_() + A = torch.randn(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() - y = Y0[:, None] + # Iterative implementation + + y = Y_init + s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] - print(f"{k} -> {y}") + s = s + y - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + s = s.sum() - print() + gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( + s, (A, X, Y_init), retain_graph=True + ) # parallel scan - Y = pscan(A, X, Y0) + Y = pscan(A, X, Y_init) - for k in range(A.size(1)): - print(f"{k} -> {Y[:,k]}") + s = Y.sum() + + gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) + + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY_init - gY_init_ref).norm()) - y = Y[:, -1] + Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + print((Y - torch.cat([Y1, Y2], dim=1)).norm())