######################################################################
-# Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
-# and O(log(T)) if not core-bounded, so that
-#
-# Y[:, 0] = Y0
-# Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
-#
-# can be computed as
-#
-# Y[:, t] = A[:, t] * Y0 + X[:, t]
-
-
-def expand(A, X):
-    if A.size(1) == 1:
-        return
-    T = 2 * (A.size(1) // 2)
-    Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
-    Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
-    Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
-    Aa[:, :, 1].mul_(Aa[:, :, 0])
-    expand(Aa[:, :, 1], Xa[:, :, 1])
-    Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
-    Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
-    if T < A.size(1):
-        X[:, -1].add_(A[:, -1].mul(X[:, -2]))
-        A[:, -1].mul_(A[:, -2])
-
-
-# Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
-
-
-def accrev(X):
-    if X.size(1) == 1:
-        return
-    T = 2 * (X.size(1) // 2)
-    Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
-    Xa[:, :, 0].add_(Xa[:, :, 1])
-    accrev(Xa[:, :, 0])
-    Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
-    if T < X.size(1):
-        X[:, 0].add_(X[:, 1])
-
class PScan(torch.autograd.Function):
+    # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
+    # and O(log(T)) if not core-bounded, so that
+    #
+    # Y[:, 0] = Y0
+    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
+    #
+    # can be computed as
+    #
+    # Y[:, t] = A[:, t] * Y0 + X[:, t]
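+    #
+    # For illustration, composing two consecutive steps of the recurrence:
+    #
+    # Y[:, t] = A[:, t] * (A[:, t-1] * Y[:, t-2] + X[:, t-1]) + X[:, t]
+    #         = (A[:, t] * A[:, t-1]) * Y[:, t-2]
+    #           + (A[:, t] * X[:, t-1] + X[:, t])
+    #
+    # which is exactly the pairwise update that expand() applies below.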
+
+    @staticmethod
+    def expand(A, X):
+        if A.size(1) == 1:
+            return
+        T = 2 * (A.size(1) // 2)
+        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
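+        # Aa / Xa view the first T (even) time steps as T // 2 consecutive
+        # pairs; each odd element then absorbs the even element before it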
+        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+        Aa[:, :, 1].mul_(Aa[:, :, 0])
+        PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
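+        # the recursion has expanded the odd positions; each even position
+        # (except the first) is completed from the odd position before it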
+        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+        if T < A.size(1):
+            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+            A[:, -1].mul_(A[:, -2])
+
+    # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
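+    # (e.g. a sequence x0, x1, x2 becomes x0+x1+x2, x1+x2, x2)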
+
+    @staticmethod
+    def accrev(X):
+        if X.size(1) == 1:
+            return
+        T = 2 * (X.size(1) // 2)
+        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
+        Xa[:, :, 0].add_(Xa[:, :, 1])
+        PScan.accrev(Xa[:, :, 0])
+        Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
+        if T < X.size(1):
+            X[:, 0].add_(X[:, 1])
+
    @staticmethod
    def forward(ctx, A, X, Y0):
        ctx.A = A[:, :, None].clone()
        ctx.Y0 = Y0[:, None, :].clone()
        ctx.A_star = A[:, :, None].clone()
        ctx.X_star = X.clone()
-        expand(ctx.A_star, ctx.X_star)
+        PScan.expand(ctx.A_star, ctx.X_star)
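+        # A_star / X_star are expanded in place so that
+        # Y[:, t] = A_star[:, t] * Y0 + X_star[:, t]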
        return ctx.A_star * ctx.Y0 + ctx.X_star

    @staticmethod
    def backward(ctx, grad_output):
        U = grad_output * ctx.A_star
        R = U.clone()
-        accrev(R)
+        PScan.accrev(R)
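+        # now R[:, s] = \sum_{t >= s} grad_output[:, t] * A_star[:, t]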
        Q = ctx.Y0 / ctx.A
        Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
-        return (Q * R).sum(-1), R / ctx.A_star, U
+        # the gradient wrt Y0 must match Y0's NxD shape, hence the sum over t
+        return (Q * R).sum(-1), R / ctx.A_star, U.sum(1)
######################################################################
if __name__ == "__main__":
+    # Iterative implementation
+
    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()

    y = Y0
    for k in range(A.size(1)):
        y = A[:, k, None] * y + X[:, k]
        print(f"{k} -> {y}")

    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
-    Y = pscan(A, X, Y0)
-
    print()
+    # Parallel scan
+
+    Y = pscan(A, X, Y0)
+
    for k in range(A.size(1)):
        print(f"{k} -> {Y[:,k]}")
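+
+    # Two optional sanity checks (illustrative additions): the last parallel
+    # output should match the last value of the iterative recurrence above,
+    # and torch.autograd.gradcheck compares the custom backward against
+    # finite differences (it relies on the double-precision tensors defined
+    # here and can be sensitive to entries of A close to zero, since the
+    # backward divides by A).
+    print(torch.allclose(Y[:, -1], y))
+    print(torch.autograd.gradcheck(PScan.apply, (A, X, Y0)))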