3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 ######################################################################
class PScan(torch.autograd.Function):
    # Parallel scan for the first-order linear recurrence
    #
    #   Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t],   with Y[:, -1] = Y_init
    #
    # Given A of shape NxTx1 and X of shape NxTxD, expand() rewrites A and X
    # in place with O(T) total work, and O(log(T)) depth if not
    # core-bounded, so that afterwards every step can be computed directly:
    #
    #   Y[:, t] = A[:, t] * Y_init + X[:, t]
    #
    # Use it as PScan.apply(A, X, Y_init); see forward() for the shapes.

    @staticmethod
    def expand(A, X):
        # In place, along dim 1:
        #   A[:, t] <- prod_{k <= t} A[:, k]
        #   X[:, t] <- sum_{s <= t} (prod_{s < k <= t} A[:, k]) * X[:, s]
        # A is NxTx1 and X is NxTxD, with the same T.
        if A.size(1) <= 1:
            # Sequences of length 0 or 1 are already in expanded form.
            return
        T = 2 * (A.size(1) // 2)  # longest even-length prefix
        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
        # Fold every pair (2i, 2i+1) into its second element.
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        Aa[:, :, 1].mul_(Aa[:, :, 0])
        # The second elements now form a half-length recurrence: solve it.
        PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
        # Carry each solved second element into the first element of the
        # following pair (the very first element has no predecessor).
        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
        if T < A.size(1):
            # Odd length: the last element was not part of any pair.
            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
            A[:, -1].mul_(A[:, -2])

    @staticmethod
    def accrev(X):
        # Computes in place Y[:, s] = sum_{t >= s} X[:, t] (reverse cumulative
        # sum along dim 1) with the same pairwise recursion as expand().
        # Not used by backward(); kept as a standalone utility.
        if X.size(1) <= 1:
            return
        T = 2 * (X.size(1) // 2)
        # Pair up the *last* T elements, so an unpaired element sits at 0.
        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
        Xa[:, :, 0].add_(Xa[:, :, 1])
        PScan.accrev(Xa[:, :, 0])
        Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
        if T < X.size(1):
            # Odd length: fold the unpaired first element in.
            X[:, 0].add_(X[:, 1])

    @staticmethod
    def forward(ctx, A, X, Y_init):
        # A is NxT, X is NxTxD, Y_init is NxD
        #
        # returns Y of same shape as X, with
        #
        #   Y[:, t] = A[:, 0] * Y_init + X[:, 0]    if t == 0
        #           = A[:, t] * Y[:, t-1] + X[:, t] otherwise
        ctx.A = A[:, :, None].clone()
        ctx.Y_init = Y_init[:, None, :].clone()
        ctx.A_star = A[:, :, None].clone()
        ctx.X_star = X.clone()
        PScan.expand(ctx.A_star, ctx.X_star)
        return ctx.A_star * ctx.Y_init + ctx.X_star

    @staticmethod
    def backward(ctx, grad_output):
        # Let G[:, t] be the total gradient reaching Y[:, t], i.e. including
        # what flows back through every later step:
        #
        #   G[:, t] = grad_output[:, t] + A[:, t+1] * G[:, t+1]
        #
        # This is the forward recurrence run backwards in time, so it is
        # solved with expand() on time-reversed tensors. It only multiplies
        # by A and never divides by A or by its cumulative products, so the
        # gradients stay finite when some A[:, t] is zero or when the
        # products underflow.
        T = grad_output.size(1)
        # A_rev[:, s] = A[:, T - s] for s >= 1. A_rev[:, 0] only affects the
        # products expand() writes back into A_rev (discarded), so it is 1.
        A_rev = torch.ones_like(ctx.A)
        A_rev[:, 1:] = ctx.A[:, 1:].flip(1)
        # torch.flip returns a copy, so G can safely be modified in place.
        G = grad_output.flip(1)
        PScan.expand(A_rev, G)
        G = G.flip(1)

        # Each A[:, t] multiplies Y[:, t-1], with Y[:, -1] = Y_init. The
        # trailing [:, :T] only changes anything when T == 0.
        Y_prev = torch.cat(
            (ctx.Y_init, ctx.A_star[:, :-1] * ctx.Y_init + ctx.X_star[:, :-1]),
            dim=1,
        )[:, :T]

        grad_A = (G * Y_prev).sum(-1)
        grad_X = G
        # Y[:, 0] = A[:, 0] * Y_init + X[:, 0]. Slicing with :1 (rather than
        # indexing 0) yields zeros instead of an IndexError when T == 0.
        grad_Y_init = (ctx.A[:, :1] * G[:, :1]).sum(1)
        return grad_A, grad_X, grad_Y_init
82 ######################################################################
84 if __name__ == "__main__":
87 A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
88 X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
89 Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
91 # Iterative implementation
96 for k in range(A.size(1)):
97 y = A[:, k, None] * y + X[:, k]
102 gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
103 s, (A, X, Y_init), retain_graph=True
108 Y = pscan(A, X, Y_init)
112 gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
114 print((gA - gA_ref).norm())
115 print((gX - gX_ref).norm())
116 print((gY_init - gY_init_ref).norm())
118 Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
119 Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
121 print((Y - torch.cat([Y1, Y2], dim=1)).norm())