PScan.accrev(R)
Q = ctx.Y0 / ctx.A
Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
- return (Q * R).sum(-1), R / ctx.A_star, U
+ return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
pscan = PScan.apply
######################################################################
if __name__ == "__main__":
+ N, T, D = 2, 5, 3
+
# Iterative implementation
- A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
- X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
- Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+ A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
- y = Y0[:, None]
+ y = Y0
+ s = 0
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
- print(f"{k} -> {y}")
+ s = s + y
+ # print(f"{k} -> {y}")
+
+ s = s.sum()
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ # print(s)
+ print(torch.autograd.grad(s, A, retain_graph=True))
+ print(torch.autograd.grad(s, X, retain_graph=True))
+ print(torch.autograd.grad(s, Y0, retain_graph=True))
print()
Y = pscan(A, X, Y0)
- for k in range(A.size(1)):
- print(f"{k} -> {Y[:,k]}")
+ # for k in range(A.size(1)):
+ # print(f"{k} -> {Y[:,k]}")
- y = Y[:, -1]
+ s = Y.sum()
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ # print(s)
+ print(torch.autograd.grad(s, A, retain_graph=True))
+ print(torch.autograd.grad(s, X, retain_graph=True))
+ print(torch.autograd.grad(s, Y0, retain_graph=True))