import torch

######################################################################

class PScan(torch.autograd.Function):
    # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
    # and O(log(T)) if not core-bounded, so that
    #
    # Y[:, 0] = Y0
    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
    #
    # can be computed as
    #
    # Y[:, t] = A[:, t] * Y0 + X[:, t]
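    #
    # i.e. after the call, A[:, t] holds the product A[:, 0] * ... * A[:, t],
    # and X[:, t] holds X[:, t] + A[:, t] * X[:, t-1] + A[:, t] * A[:, t-1] * X[:, t-2] + ...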

    @staticmethod
    def expand(A, X):
        if A.size(1) == 1:
            return
        T = 2 * (A.size(1) // 2)
        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        Aa[:, :, 1].mul_(Aa[:, :, 0])
        PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
        if T < A.size(1):
            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
            A[:, -1].mul_(A[:, -2])
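
    # Each pass of expand pairs consecutive affine maps y -> a * y + x and
    # composes them, using a1 * (a0 * y + x0) + x1 = (a0 * a1) * y + (a1 * x0 + x1):
    # the composed pairs are expanded recursively, the even positions are
    # then corrected from their left neighbours, and a trailing element
    # (odd T) is folded in last, giving O(log(T)) sequential depth.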

    # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]

    @staticmethod
    def accrev(X):
        if X.size(1) == 1:
            return
        T = 2 * (X.size(1) // 2)
        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
        Xa[:, :, 0].add_(Xa[:, :, 1])
        PScan.accrev(Xa[:, :, 0])
        Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
        if T < X.size(1):
            X[:, 0].add_(X[:, 1])
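
    # For intuition: accrev is an in-place reversed cumulative sum, so
    # R = Z.clone(); PScan.accrev(R) leaves R equal to
    # Z.flip(1).cumsum(1).flip(1), again with O(log(T)) sequential depth.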

    # A is NxT, X is NxTxD, Y0 is NxD

    @staticmethod
    def forward(ctx, A, X, Y0):
        ctx.A = A[:, :, None].clone()
        ctx.Y0 = Y0[:, None, :].clone()
        ctx.A_star = A[:, :, None].clone()
        ctx.X_star = X.clone()
        PScan.expand(ctx.A_star, ctx.X_star)
        return ctx.A_star * ctx.Y0 + ctx.X_star
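
    # Note that forward stashes A, Y0, and the expanded A*, X* on ctx;
    # all four reappear in the gradient formulas in backward.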

    @staticmethod
    def backward(ctx, grad_output):
        U = grad_output * ctx.A_star
        R = U.clone()
        PScan.accrev(R)
        Q = ctx.Y0 / ctx.A
        Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
        return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
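
    # Why backward takes this form: since Y[:, t] = A*[:, t] * Y0 + X*[:, t],
    # dY[:, t]/dY0 = A*[:, t], so the Y0 gradient is U summed over time;
    # dY[:, t]/dX[:, s] = A*[:, t] / A*[:, s] for t >= s, so the X gradient
    # is the reversed accumulation R divided by A*; the A gradient combines
    # the same quantities through Q.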

pscan = PScan.apply

######################################################################

if __name__ == "__main__":
    # Iterative implementation

    A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
    X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
    Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()

    y = Y0[:, None]

    for k in range(A.size(1)):
        y = A[:, k, None] * y + X[:, k]
        print(f"{k} -> {y}")

    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))

    # Parallel implementation

    Y = pscan(A, X, Y0)

    for k in range(A.size(1)):
        print(f"{k} -> {Y[:,k]}")

    y = Y[:, -1]

    print(torch.autograd.grad(y.mean(), A, retain_graph=True))
    print(torch.autograd.grad(y.mean(), X, retain_graph=True))
    print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
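
    # As a further check (a sketch; y_ref, g_ref, g_par are names
    # introduced here for illustration): the parallel scan should
    # reproduce the iterative recurrence and its gradients up to
    # floating-point error.
    y_ref = Y0
    for k in range(A.size(1)):
        y_ref = A[:, k, None] * y_ref + X[:, k]
    assert torch.allclose(Y[:, -1], y_ref)
    for inp in (A, X, Y0):
        g_ref, = torch.autograd.grad(y_ref.mean(), inp, retain_graph=True)
        g_par, = torch.autograd.grad(Y[:, -1].mean(), inp, retain_graph=True)
        assert torch.allclose(g_ref, g_par)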