if __name__ == "__main__":
    # Sanity check: compare the gradients of the parallel-scan implementation
    # (`pscan`, defined elsewhere in this file) against a straightforward
    # iterative computation of the same linear recurrence
    #     y[t] = A[:, t] * y[t-1] + X[:, t],   y[-1] = Y0
    # All three printed norms should be ~0 (up to float64 round-off).
    N, T, D = 2, 5, 3  # batch size, sequence length, feature dimension

    # float64 inputs so the two computations agree to tight tolerance
    A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()

    # Iterative (reference) implementation

    y = Y0
    s = 0
    for k in range(A.size(1)):
        # A[:, k, None] broadcasts the per-step coefficient over the D features
        y = A[:, k, None] * y + X[:, k]
        s = s + y
    s = s.sum()
    # retain_graph so the same graph could be differentiated again if needed
    gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)

    # Parallel-scan implementation

    Y = pscan(A, X, Y0)
    s = Y.sum()
    gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)

    # Norm of the gradient discrepancies — should all be ~0
    print((gA - gA_ref).norm())
    print((gX - gX_ref).norm())
    print((gY0 - gY0_ref).norm())