# Given A is NxTx1 and X is NxTxD, expands A and X in place with O(T)
# work, in O(log(T)) steps if not core-bounded, so that
#
- # Y[:, 0] = Y0
+ # Y[:, 0] = Y_init
# Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
#
# can be computed as
#
- # Y[:, t] = A[:, t] * Y0 + X[:, t]
+ # Y[:, t] = A[:, t] * Y_init + X[:, t]
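+ #
+ # that is, after expansion,
+ #
+ # A[:, t] = A[:, t] * A[:, t-1] * ... * A[:, 0]
+ # X[:, t] = X[:, t] + A[:, t] * X[:, t-1] + ... + A[:, t] * ... * A[:, 1] * X[:, 0]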
@staticmethod
def expand(A, X):
    # Sequential O(T) reference; an O(log(T)) variant is sketched below
    for t in range(1, X.size(1)):
        X[:, t].add_(A[:, t] * X[:, t - 1])
        A[:, t].mul_(A[:, t - 1])
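+
+ # A minimal sketch of the O(log(T)) variant mentioned above, assuming a
+ # standard Blelloch-style even/odd blocking (illustrative, and the name
+ # expand_log is ours; not necessarily the exact scheme used elsewhere in
+ # this file): fold each odd slot into a pair summary, recurse on the
+ # half-length problem, then fix up the even slots.
+ @staticmethod
+ def expand_log(A, X):
+     if A.size(1) == 1:
+         return
+     T = 2 * (A.size(1) // 2)
+     Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+     Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
+     # Combine each even/odd pair into the odd slot
+     Xa[:, :, 1].add_(Aa[:, :, 1] * Xa[:, :, 0])
+     Aa[:, :, 1].mul_(Aa[:, :, 0])
+     # The odd slots now form a half-length instance of the same problem
+     PScan.expand_log(Aa[:, :, 1], Xa[:, :, 1])
+     # Propagate the prefix results back into the remaining even slots
+     Xa[:, 1:, 0].add_(Aa[:, 1:, 0] * Xa[:, :-1, 1])
+     Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+     if T < A.size(1):  # odd T: fold in the trailing element
+         X[:, -1].add_(A[:, -1] * X[:, -2])
+         A[:, -1].mul_(A[:, -2])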
+ # A is NxT, X is NxTxD, Y_init is NxD
+ #
+ # returns Y of same shape as X, with
+ #
+ # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0
+ # = A[:, t] * Y[:, t-1] + X[:, t] otherwise
+
@staticmethod
- def forward(ctx, A, X, Y0):
+ def forward(ctx, A, X, Y_init):
ctx.A = A[:, :, None].clone()
- ctx.Y0 = Y0[:, None, :].clone()
+ ctx.Y_init = Y_init[:, None, :].clone()
ctx.A_star = A[:, :, None].clone()
ctx.X_star = X.clone()
PScan.expand(ctx.A_star, ctx.X_star)
- return ctx.A_star * ctx.Y0 + ctx.X_star
+ return ctx.A_star * ctx.Y_init + ctx.X_star
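+
+ # Backward through Y[:, t] = A_star[:, t] * Y_init + X_star[:, t], with
+ # A_star[:, t] = A[:, t] * ... * A[:, 0]:
+ #
+ # dL/dY_init = sum_t grad[:, t] * A_star[:, t]
+ # dL/dX[:, s] = R[:, s] / A_star[:, s]
+ # dL/dA[:, s] = sum over D of Q[:, s] * R[:, s]
+ #
+ # where R[:, s] = sum_{t>=s} grad[:, t] * A_star[:, t] and
+ # Q[:, s] = (Y_init + X_star[:, s-1] / A_star[:, s-1]) / A[:, s].
+ # accrev (defined elsewhere in this file) is the in-place reverse
+ # cumulative sum over t. The divisions assume A has no zero entries.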
@staticmethod
def backward(ctx, grad_output):
U = grad_output * ctx.A_star
R = U.clone()
PScan.accrev(R)
- Q = ctx.Y0 / ctx.A
+ Q = ctx.Y_init / ctx.A
Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
- return (Q * R).sum(-1), R / ctx.A_star, U
+ return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
pscan = PScan.apply
######################################################################
if __name__ == "__main__":
- # Iterative implementation
+ N, T, D = 2, 5, 3
- A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
- X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
- Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+ A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
- y = Y0[:, None]
+ # Iterative implementation
+
+ y = Y_init
+ s = 0
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
- print(f"{k} -> {y}")
+ s = s + y
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ s = s.sum()
- print()
+ gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+ s, (A, X, Y_init), retain_graph=True
+ )
# Parallel scan
- Y = pscan(A, X, Y0)
+ Y = pscan(A, X, Y_init)
- for k in range(A.size(1)):
- print(f"{k} -> {Y[:,k]}")
+ s = Y.sum()
+
+ gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
+
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY_init - gY_init_ref).norm())
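+
+ # The scan can be chunked: scanning the two halves separately, feeding
+ # the last state of the first half in as the initial value of the
+ # second, must reproduce the full scan.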
- y = Y[:, -1]
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())
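+
+ # A hypothetical extra sanity check (not in the original test): gradcheck
+ # compares the custom backward against finite differences, and needs the
+ # float64 inputs used above.
+ print(torch.autograd.gradcheck(pscan, (A, X, Y_init)))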