#!/usr/bin/env python
-import torch
-
-######################################################################
-
-
-def naive_rec(A, X, Y0):
- Y = []
- for t in range(X.size(1)):
- if t == 0:
- Y.append(A[:, t] * Y0 + X[:, t])
- else:
- Y.append(A[:, t] * Y[-1] + X[:, t])
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
- return torch.cat([y[:, None, :] for y in Y], dim=1)
+# Written by Francois Fleuret <francois@fleuret.org>
+import torch
######################################################################
-# A is NxTx1
-# X is NxTxD
-# Y0 is NxD
-#
-# Returns Y defined with
-#
-# Y[:, 0] = A[:, 0] * Y0 + X[:,0]
-# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
-
-def pscan_rec(A, X, Y0):
- if X.size(1) % 2 == 1:
+class PScan(torch.autograd.Function):
+ # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
+ # and O(log(T)) if not core-bounded, so that
+ #
+ # Y[:, 0] = Y_init
+ # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
+ #
+ # can be computed as
+ #
+ # Y[:, t] = A[:, t] * Y_init + X[:, t]
+
+ @staticmethod
+ def expand(A, X):
+ if A.size(1) == 1:
+ return
+ T = 2 * (A.size(1) // 2)
+ Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+ Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
+ PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
+ Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+ Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+ if T < A.size(1):
+ X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+ A[:, -1].mul_(A[:, -2])
+
+ # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
+
+ @staticmethod
+ def accrev(X):
if X.size(1) == 1:
- return A[:, :1] * Y0[:, None] + X[:, :1]
- else:
- Y = pscan_rec(A[:, :-1], X[:, :-1], Y0)
- return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1)
-
- A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2))
- X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2))
+ return
+ T = 2 * (X.size(1) // 2)
+ Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
+ Xa[:, :, 0].add_(Xa[:, :, 1])
+ PScan.accrev(Xa[:, :, 0])
+ Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
+ if T < X.size(1):
+ X[:, 0].add_(X[:, 1])
+
+ # A is NxT, X is NxTxD, Y_init is NxD
+ #
+ # returns Y of same shape as X, with
+ #
+ # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0
+ # = A[:, t] * Y[:, t-1] + X[:, t] otherwise
+
+ @staticmethod
+ def forward(ctx, A, X, Y_init):
+ ctx.A = A[:, :, None].clone()
+ ctx.Y_init = Y_init[:, None, :].clone()
+ ctx.A_star = A[:, :, None].clone()
+ ctx.X_star = X.clone()
+ PScan.expand(ctx.A_star, ctx.X_star)
+ return ctx.A_star * ctx.Y_init + ctx.X_star
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ U = grad_output * ctx.A_star
+ R = U.clone()
+ PScan.accrev(R)
+ Q = ctx.Y_init / ctx.A
+ Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
+ return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
+
+
+pscan = PScan.apply
- X_star = X2[:, :, 0].clone()
- X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1]
+######################################################################
- A_star = A2[:, :, 0].clone()
- A_star[:, 1:] *= A2[:, :-1, 1]
+if __name__ == "__main__":
+ N, T, D = 2, 5, 3
- Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None]
+ A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
- Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2)
+ # Iterative implementation
- Y = Y.reshape(Y.size(0), -1, Y.size(-1))
+ y = Y_init
+ s = 0
- return Y
+ for k in range(A.size(1)):
+ y = A[:, k, None] * y + X[:, k]
+ s = s + y
+ s = s.sum()
-######################################################################
+ gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+ s, (A, X, Y_init), retain_graph=True
+ )
-N, T, D = 5, 29, 12
+ # parallel scan
-A = torch.rand(N, T, 1, dtype=torch.float64)
-X = torch.randint(10, (N, T, D), dtype=torch.float64)
-Y0 = torch.randint(10, (N, D), dtype=torch.float64)
+ Y = pscan(A, X, Y_init)
-naive_Y = naive_rec(A, X, Y0)
+ s = Y.sum()
-pscan_Y = pscan_rec(A, X, Y0)
+ gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
-print((naive_Y - pscan_Y).pow(2).mean())
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY_init - gY_init_ref).norm())
-pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0)
-pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1])
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
-print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean())
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())