projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
e288655
)
Update.
author
François Fleuret
<francois@fleuret.org>
Mon, 18 Dec 2023 01:57:06 +0000
(
02:57
+0100)
committer
François Fleuret
<francois@fleuret.org>
Mon, 18 Dec 2023 01:57:06 +0000
(
02:57
+0100)
pscan.py
patch
|
blob
|
history
diff --git
a/pscan.py
b/pscan.py
index
1dfb442
..
071f284
100755
(executable)
--- a/
pscan.py
+++ b/
pscan.py
@@
-77,39
+77,31
@@
pscan = PScan.apply
if __name__ == "__main__":
N, T, D = 2, 5, 3
if __name__ == "__main__":
N, T, D = 2, 5, 3
- # Iterative implementation
-
A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+ # Iterative implementation
+
y = Y0
s = 0
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
s = s + y
y = Y0
s = 0
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
s = s + y
- # print(f"{k} -> {y}")
s = s.sum()
s = s.sum()
- # print(s)
- print(torch.autograd.grad(s, A, retain_graph=True))
- print(torch.autograd.grad(s, X, retain_graph=True))
- print(torch.autograd.grad(s, Y0, retain_graph=True))
-
- print()
+ gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
# parallel scan
Y = pscan(A, X, Y0)
# parallel scan
Y = pscan(A, X, Y0)
- # for k in range(A.size(1)):
- # print(f"{k} -> {Y[:,k]}")
-
s = Y.sum()
s = Y.sum()
- # print(s)
- print(torch.autograd.grad(s, A, retain_graph=True))
- print(torch.autograd.grad(s, X, retain_graph=True))
- print(torch.autograd.grad(s, Y0, retain_graph=True))
+ gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY0 - gY0_ref).norm())