pytorch.git / commitdiff (parent: 59513fa)
Update.
author    François Fleuret <francois@fleuret.org>
          Mon, 18 Dec 2023 03:52:50 +0000 (04:52 +0100)
committer François Fleuret <francois@fleuret.org>
          Mon, 18 Dec 2023 03:52:50 +0000 (04:52 +0100)
pscan.py
diff --git a/pscan.py b/pscan.py
index 071f284..3526c31 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -14,12 +14,12 @@ class PScan(torch.autograd.Function):
     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
     # and O(log(T)) if not core-bounded, so that
     #
-    # Y[:, 0] = Y0
+    # Y[:, 0] = Y_init
     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
     #
     # can be computed as
     #
-    # Y[:, t] = A[:, t] * Y0 + X[:, t]
+    # Y[:, t] = A[:, t] * Y_init + X[:, t]
 
     @staticmethod
     def expand(A, X):
@@ -51,21 +51,28 @@ class PScan(torch.autograd.Function):
         if T < X.size(1):
             X[:, 0].add_(X[:, 1])
 
+    # A is NxT, X is NxTxD, Y_init is NxD
+    #
+    # returns Y of same shape as X, with
+    #
+    # Y[:,t] = A[:,0] * Y_init + X[:,0] if t == 0
+    #        = A[:,t] * Y[:,t-1] + X[:,t] otherwise
+
     @staticmethod
-    def forward(ctx, A, X, Y0):
+    def forward(ctx, A, X, Y_init):
         ctx.A = A[:, :, None].clone()
-        ctx.Y0 = Y0[:, None, :].clone()
+        ctx.Y_init = Y_init[:, None, :].clone()
         ctx.A_star = A[:, :, None].clone()
         ctx.X_star = X.clone()
         PScan.expand(ctx.A_star, ctx.X_star)
-        return ctx.A_star * ctx.Y0 + ctx.X_star
+        return ctx.A_star * ctx.Y_init + ctx.X_star
 
     @staticmethod
     def backward(ctx, grad_output):
         U = grad_output * ctx.A_star
         R = U.clone()
         PScan.accrev(R)
-        Q = ctx.Y0 / ctx.A
+        Q = ctx.Y_init / ctx.A
         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
         return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
 
@@ -79,11 +86,11 @@ if __name__ == "__main__":
 
     A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-    Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+    Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
     # Iterative implementation
 
-    y = Y0
+    y = Y_init
     s = 0
 
     for k in range(A.size(1)):
@@ -92,16 +99,18 @@ if __name__ == "__main__":
 
     s = s.sum()
 
-    gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+    gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+        s, (A, X, Y_init), retain_graph=True
+    )
 
     # parallel scan
 
-    Y = pscan(A, X, Y0)
+    Y = pscan(A, X, Y_init)
 
     s = Y.sum()
 
-    gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+    gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
 
     print((gA - gA_ref).norm())
     print((gX - gX_ref).norm())
-    print((gY0 - gY0_ref).norm())
+    print((gY_init - gY_init_ref).norm())
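The comment block added in this commit documents the contract of forward(): given A of shape NxT, X of shape NxTxD and Y_init of shape NxD, the result Y satisfies Y[:,0] = A[:,0] * Y_init + X[:,0] and Y[:,t] = A[:,t] * Y[:,t-1] + X[:,t]. As a quick illustration of that contract, here is a minimal sketch (not part of the commit) that checks pscan against the sequential recurrence. It assumes pscan.py is importable from the working directory and defines pscan = PScan.apply at module level, as the test block's call Y = pscan(A, X, Y_init) suggests.

    # Minimal sketch (not part of the commit). The import below is an
    # assumption: it requires pscan.py to be importable and to expose
    # pscan = PScan.apply at module level.
    import torch

    from pscan import pscan

    N, T, D = 2, 5, 3

    A = torch.randn(N, T, dtype=torch.float64)
    X = torch.randn(N, T, D, dtype=torch.float64)
    Y_init = torch.randn(N, D, dtype=torch.float64)

    # Sequential reference, following the documented recurrence:
    #   Y[:, 0] = A[:, 0] * Y_init + X[:, 0]
    #   Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
    y = Y_init
    ys = []
    for t in range(T):
        y = A[:, t, None] * y + X[:, t]
        ys.append(y)
    Y_ref = torch.stack(ys, dim=1)  # NxTxD, same shape as X

    Y = pscan(A, X, Y_init)

    print((Y - Y_ref).norm())  # should be ~0, up to float64 rounding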