projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
263fe9c
)
Update.
author
François Fleuret
<francois@fleuret.org>
Mon, 18 Dec 2023 01:53:53 +0000
(
02:53
+0100)
committer
François Fleuret
<francois@fleuret.org>
Mon, 18 Dec 2023 01:53:53 +0000
(
02:53
+0100)
pscan.py
patch
|
blob
|
history
diff --git a/pscan.py b/pscan.py
index 6a9057e..1dfb442 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -67,7 +67,7 @@ class PScan(torch.autograd.Function):
PScan.accrev(R)
Q = ctx.Y0 / ctx.A
Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
PScan.accrev(R)
Q = ctx.Y0 / ctx.A
Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
-        return (Q * R).sum(-1), R / ctx.A_star, U
+        return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
pscan = PScan.apply
pscan = PScan.apply
@@ -75,21 +75,28 @@ pscan = PScan.apply
######################################################################
if __name__ == "__main__":
######################################################################
if __name__ == "__main__":
+ N, T, D = 2, 5, 3
+
# Iterative implementation
# Iterative implementation
-    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
-    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
-    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+    A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
- y = Y0[:, None]
+ y = Y0
+ s = 0
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
for k in range(A.size(1)):
y = A[:, k, None] * y + X[:, k]
- print(f"{k} -> {y}")
+ s = s + y
+ # print(f"{k} -> {y}")
+
+ s = s.sum()
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ # print(s)
+ print(torch.autograd.grad(s, A, retain_graph=True))
+ print(torch.autograd.grad(s, X, retain_graph=True))
+ print(torch.autograd.grad(s, Y0, retain_graph=True))
print()
print()
@@ -97,11 +104,12 @@ if __name__ == "__main__":
Y = pscan(A, X, Y0)
Y = pscan(A, X, Y0)
-    for k in range(A.size(1)):
-        print(f"{k} -> {Y[:,k]}")
+
+    # for k in range(A.size(1)):
+    #     print(f"{k} -> {Y[:,k]}")
- y = Y[:, -1]
+ s = Y.sum()
- print(torch.autograd.grad(y.mean(), A, retain_graph=True))
- print(torch.autograd.grad(y.mean(), X, retain_graph=True))
- print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
+ # print(s)
+ print(torch.autograd.grad(s, A, retain_graph=True))
+ print(torch.autograd.grad(s, X, retain_graph=True))
+ print(torch.autograd.grad(s, Y0, retain_graph=True))