X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=1dfb44229b88f96fae0a5c8ff67f7df4ee968a5a;hb=e2886553f1cd8fce2e31b352935330cb7cb5b325;hp=f344200160cc1d8ccb4ae260e9ea5786c0af74fb;hpb=62f1bc82f6c8107960be0009603beb5b3e386a6a;p=pytorch.git diff --git a/pscan.py b/pscan.py index f344200..1dfb442 100755 --- a/pscan.py +++ b/pscan.py @@ -1,69 +1,73 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import torch ###################################################################### -# Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), -# and O(log(T)) if not core-bounded, so that -# -# Y[:, 0] = Y0 -# Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] -# -# can be computed as -# -# Y[:, t] = A[:, t] * Y0 + X[:, t] - - -def expand(A, X): - if A.size(1) == 1: - return - T = 2 * (A.size(1) // 2) - Aa = A[:, :T].view(A.size(0), T // 2, 2, -1) - Xa = X[:, :T].view(X.size(0), T // 2, 2, -1) - Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) - Aa[:, :, 1].mul_(Aa[:, :, 0]) - expand(Aa[:, :, 1], Xa[:, :, 1]) - Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) - Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) - if T < A.size(1): - X[:, -1].add_(A[:, -1].mul(X[:, -2])) - A[:, -1].mul_(A[:, -2]) - - -# Computes inplace Y[:, s] = \sum_{t >= s} X[:, t] - - -def accrev(X): - if X.size(1) == 1: - return - T = 2 * (X.size(1) // 2) - Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1) - Xa[:, :, 0].add_(Xa[:, :, 1]) - accrev(Xa[:, :, 0]) - Xa[:, :-1, 1].add_(Xa[:, 1:, 0]) - if T < X.size(1): - X[:, 0].add_(X[:, 1]) - class PScan(torch.autograd.Function): + # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), + # and O(log(T)) if not core-bounded, so that + # + # Y[:, 0] = Y0 + # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] + # + # can be computed as + # + # Y[:, t] = A[:, t] * Y0 + X[:, t] + + @staticmethod + def expand(A, X): + if A.size(1) == 1: + return + T = 2 * (A.size(1) // 2) + Aa = A[:, :T].view(A.size(0), T // 2, 2, -1) + Xa = X[:, :T].view(X.size(0), T // 2, 2, -1) + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + PScan.expand(Aa[:, :, 1], Xa[:, :, 1]) + Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) + Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) + if T < A.size(1): + X[:, -1].add_(A[:, -1].mul(X[:, -2])) + A[:, -1].mul_(A[:, -2]) + + # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t] + + @staticmethod + def accrev(X): + if X.size(1) == 1: + return + T = 2 * (X.size(1) // 2) + Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1) + Xa[:, :, 0].add_(Xa[:, :, 1]) + PScan.accrev(Xa[:, :, 0]) + Xa[:, :-1, 1].add_(Xa[:, 1:, 0]) + if T < X.size(1): + X[:, 0].add_(X[:, 1]) + @staticmethod def forward(ctx, A, X, Y0): ctx.A = A[:, :, None].clone() ctx.Y0 = Y0[:, None, :].clone() ctx.A_star = A[:, :, None].clone() ctx.X_star = X.clone() - expand(ctx.A_star, ctx.X_star) + PScan.expand(ctx.A_star, ctx.X_star) return ctx.A_star * ctx.Y0 + ctx.X_star @staticmethod def backward(ctx, grad_output): U = grad_output * ctx.A_star R = U.clone() - accrev(R) + PScan.accrev(R) Q = ctx.Y0 / ctx.A Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) - return (Q * R).sum(-1), R / ctx.A_star, U + return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1) pscan = PScan.apply @@ -71,29 +75,41 @@ pscan = PScan.apply ###################################################################### if __name__ == "__main__": - A = torch.randn(1, 5, dtype=torch.float64).requires_grad_() - X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_() - Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_() + N, T, D = 2, 5, 3 - y = Y0[:, None] + # Iterative implementation + + A = torch.randn(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() + + y = Y0 + s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] - print(f"{k} -> {y}") + s = s + y + # print(f"{k} -> {y}") - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + s = s.sum() - Y = pscan(A, X, Y0) + # print(s) + print(torch.autograd.grad(s, A, retain_graph=True)) + print(torch.autograd.grad(s, X, retain_graph=True)) + print(torch.autograd.grad(s, Y0, retain_graph=True)) print() - for k in range(A.size(1)): - print(f"{k} -> {Y[:,k]}") + # parallel scan + + Y = pscan(A, X, Y0) + + # for k in range(A.size(1)): + # print(f"{k} -> {Y[:,k]}") - y = Y[:, -1] + s = Y.sum() - print(torch.autograd.grad(y.mean(), A, retain_graph=True)) - print(torch.autograd.grad(y.mean(), X, retain_graph=True)) - print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) + # print(s) + print(torch.autograd.grad(s, A, retain_graph=True)) + print(torch.autograd.grad(s, X, retain_graph=True)) + print(torch.autograd.grad(s, Y0, retain_graph=True))