PScan.accrev(R)
Q = ctx.Y0 / ctx.A
Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
- return (Q * R).sum(-1), R / ctx.A_star, U
+ return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
# Functional entry point: call as pscan(A, X, Y0) -> Y, where gradients flow
# through the custom autograd Function defined above (PScan.apply).
pscan = PScan.apply

######################################################################
if __name__ == "__main__":
    # Self-test: compare the parallel scan against a straightforward
    # iterative implementation of  y[t] = A[t] * y[t-1] + X[t],
    # checking both the forward values (through the summed loss) and the
    # gradients w.r.t. A, X and Y0.  float64 keeps the comparison tight.
    # NOTE(review): the original file contained unresolved diff markers
    # here; this is the resolved (post-patch) version.
    N, T, D = 2, 5, 3

    A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()

    # Iterative reference implementation: accumulate the sum of all y[t]
    # so the loss depends on every time step, not just the last one.
    y = Y0
    s = 0
    for k in range(A.size(1)):
        # A[:, k, None] broadcasts the scalar coefficient over the D dims.
        y = A[:, k, None] * y + X[:, k]
        s = s + y
    s = s.sum()
    # retain_graph=True so the graph survives for any further inspection.
    gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)

    # Parallel scan: same loss, same gradients expected.
    Y = pscan(A, X, Y0)
    s = Y.sum()
    gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)

    # All three norms should be ~0 (float64 round-off only).
    print((gA - gA_ref).norm())
    print((gX - gX_ref).norm())
    print((gY0 - gY0_ref).norm())