X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=6a9057e7644e5977879f3aae43534d24629245b3;hb=263fe9caccf360a68559bbc63d11267ff319b18d;hp=36490ff2f201f19d7e239689595500ece63e8283;hpb=8c12767fe586074920e3d4abb05e4393a145351a;p=pytorch.git diff --git a/pscan.py b/pscan.py index 36490ff..6a9057e 100755 --- a/pscan.py +++ b/pscan.py @@ -1,77 +1,107 @@ #!/usr/bin/env python -import math +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ -import torch, torchvision +# Written by Francois Fleuret -from torch import nn -from torch.nn import functional as F +import torch ###################################################################### -def naive_rec(A, X, Y0): - Y = [] - for t in range(X.size(1)): - if t == 0: - Y.append(A[:, t] * Y0 + X[:, t]) - else: - Y.append(A[:, t] * Y[-1] + X[:, t]) - - return torch.cat([y[:, None, :] for y in Y], dim=1) - - -###################################################################### - -# A is NxTx1 and X is NxTxD -# -# Returns Y defined with -# -# Y[:, 0] = A[:, 0] * Y0 + X[:,0] -# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t] - - -def pscan_rec(A, X, Y0): - if X.size(1) % 2 == 1: +class PScan(torch.autograd.Function): + # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), + # and O(log(T)) if not core-bounded, so that + # + # Y[:, 0] = Y0 + # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] + # + # can be computed as + # + # Y[:, t] = A[:, t] * Y0 + X[:, t] + + @staticmethod + def expand(A, X): + if A.size(1) == 1: + return + T = 2 * (A.size(1) // 2) + Aa = A[:, :T].view(A.size(0), T // 2, 2, -1) + Xa = X[:, :T].view(X.size(0), T // 2, 2, -1) + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + PScan.expand(Aa[:, :, 1], Xa[:, :, 1]) + Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) + Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) + if T < A.size(1): + X[:, -1].add_(A[:, -1].mul(X[:, -2])) + A[:, -1].mul_(A[:, -2]) + + # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t] + + @staticmethod + def accrev(X): if X.size(1) == 1: - return A[:, :1] * Y0[:, None] + X[:, :1] - else: - Y = pscan_rec(A[:, :-1], X[:, :-1], Y0) - return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1) + return + T = 2 * (X.size(1) // 2) + Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1) + Xa[:, :, 0].add_(Xa[:, :, 1]) + PScan.accrev(Xa[:, :, 0]) + Xa[:, :-1, 1].add_(Xa[:, 1:, 0]) + if T < X.size(1): + X[:, 0].add_(X[:, 1]) + + @staticmethod + def forward(ctx, A, X, Y0): + ctx.A = A[:, :, None].clone() + ctx.Y0 = Y0[:, None, :].clone() + ctx.A_star = A[:, :, None].clone() + ctx.X_star = X.clone() + PScan.expand(ctx.A_star, ctx.X_star) + return ctx.A_star * ctx.Y0 + ctx.X_star + + @staticmethod + def backward(ctx, grad_output): + U = grad_output * ctx.A_star + R = U.clone() + PScan.accrev(R) + Q = ctx.Y0 / ctx.A + Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) + return (Q * R).sum(-1), R / ctx.A_star, U + + +pscan = PScan.apply - A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2)) - X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2)) - - X_star = X2[:, :, 0].clone() - X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1] - - A_star = A2[:, :, 0].clone() - A_star[:, 1:] *= A2[:, :-1, 1] - - Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None] - - Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2) +###################################################################### - Y = Y.reshape(Y.size(0), -1, Y.size(-1)) +if __name__ == "__main__": + # Iterative implementation - return Y + A = torch.randn(1, 5, dtype=torch.float64).requires_grad_() + X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_() + Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_() + y = Y0[:, None] -###################################################################### + for k in range(A.size(1)): + y = A[:, k, None] * y + X[:, k] + print(f"{k} -> {y}") -N, T, D = 5, 29, 12 + print(torch.autograd.grad(y.mean(), A, retain_graph=True)) + print(torch.autograd.grad(y.mean(), X, retain_graph=True)) + print(torch.autograd.grad(y.mean(), Y0, retain_graph=True)) -A = torch.rand(N, T, 1, dtype=torch.float64) -X = torch.randint(10, (N, T, D), dtype=torch.float64) -Y0 = torch.randint(10, (N, D), dtype=torch.float64) + print() -naive_Y = naive_rec(A, X, Y0) + # parallel scan -pscan_Y = pscan_rec(A, X, Y0) + Y = pscan(A, X, Y0) -print((naive_Y - pscan_Y).pow(2).mean()) + for k in range(A.size(1)): + print(f"{k} -> {Y[:,k]}") -pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0) -pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1]) + y = Y[:, -1] -print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean()) + print(torch.autograd.grad(y.mean(), A, retain_graph=True)) + print(torch.autograd.grad(y.mean(), X, retain_graph=True)) + print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))