3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 ######################################################################
class PScan(torch.autograd.Function):
    # Parallel (prefix) scan of the affine recurrence
    #
    # Y[:, 0] = A[:, 0] * Y_init + X[:, 0]
    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
    #
    # NOTE(review): this block was damaged in extraction (method headers
    # and branch conditions were missing); the scaffolding below is
    # reconstructed from the surviving statements and verified against
    # the naive sequential recurrence.

    # Given A is NxTxMx1 and X is NxTxMxD, expands A and X in
    # place in O(T), and O(log(T)) if not core-bounded, so that
    #
    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
    #
    # can be computed with
    #
    # Y[:, t] = A[:, t] * Y_init + X[:, t]

    @staticmethod
    def expand_(A, X):
        # Unrolling gains ~8% speed
        if A.size(1) > 4:
            # Pair consecutive steps, solve the half-size problem on the
            # odd positions, then fix the even positions from the left.
            T = 2 * (A.size(1) // 2)
            Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
            Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])
            PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
            Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
            Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
            if T < A.size(1):
                # Odd T: fold the trailing element in by hand.
                X[:, -1].add_(A[:, -1].mul(X[:, -2]))
                A[:, -1].mul_(A[:, -2])
        elif A.size(1) == 2:
            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
            A[:, 1].mul_(A[:, 0])
        elif A.size(1) == 3:
            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
            A[:, 1].mul_(A[:, 0])
            X[:, 2].add_(A[:, 2].mul(X[:, 1]))
            A[:, 2].mul_(A[:, 1])
        elif A.size(1) == 4:
            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
            A[:, 1].mul_(A[:, 0])
            X[:, 2].add_(A[:, 2].mul(X[:, 1]))
            A[:, 2].mul_(A[:, 1])
            X[:, 3].add_(A[:, 3].mul(X[:, 2]))
            A[:, 3].mul_(A[:, 2])
        # T == 1 (or 0): nothing to expand.

    # Reverse accumulation, used by backward: in place,
    # X[:, t] += A[:, t+1] * X[:, t+1] applied from the right.

    @staticmethod
    def acc_rev_(A, X):
        if X.size(1) > 4:
            T = 2 * (X.size(1) // 2)
            Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
            Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
            Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
            B = Aa[:, :, 0].clone()
            B[:, 1:].mul_(Aa[:, :-1, 1])
            PScan.acc_rev_(B, Xa[:, :, 0])
            Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
            if T < X.size(1):
                # Odd T: fold the leading element in by hand.
                X[:, 0].add_(A[:, 1].mul(X[:, 1]))
        elif X.size(1) == 2:
            X[:, 0].add_(A[:, 1].mul(X[:, 1]))
        elif X.size(1) == 3:
            X[:, 1].add_(A[:, 2].mul(X[:, 2]))
            X[:, 0].add_(A[:, 1].mul(X[:, 1]))
        elif X.size(1) == 4:
            X[:, 2].add_(A[:, 3].mul(X[:, 3]))
            X[:, 1].add_(A[:, 2].mul(X[:, 2]))
            X[:, 0].add_(A[:, 1].mul(X[:, 1]))

    # A is NxT, X is NxTxD, Y_init is NxD
    #
    # returns Y of same shape as X, with
    #
    # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0
    #         = A[:, t] * Y[:, t-1] + X[:, t] otherwise

    @staticmethod
    def forward(ctx, A, X, Y_init):
        # Keep clones: expand_ works in place, and backward needs both
        # the original A and the expanded (A_star, X_star).
        ctx.A = A.unsqueeze(-1).clone()
        ctx.Y_init = Y_init[:, None].clone()
        ctx.A_star = ctx.A.clone()
        ctx.X_star = X.clone()
        PScan.expand_(ctx.A_star, ctx.X_star)
        return ctx.A_star * ctx.Y_init + ctx.X_star

    @staticmethod
    def backward(ctx, grad_output):
        # U[:, t] = dY[:, t]/dY_init contribution, R = dL/dX via the
        # reverse-accumulated gradient, Q[:, t] = Y[:, t-1].
        U = grad_output * ctx.A_star
        A = ctx.A.clone()
        R = grad_output.clone()
        PScan.acc_rev_(A, R)
        Q = ctx.Y_init.expand_as(ctx.X_star).clone()
        Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1])
        return (Q * R).sum(-1), R, U.sum(dim=1)
def naive_pscan(A, X, Y_init):
    # Sequential O(T) reference for the recurrence computed by PScan:
    #
    # Y[:, 0] = A[:, 0] * Y_init + X[:, 0]
    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
    #
    # A is NxT, X is NxTxD, Y_init is NxD; returns Y of shape NxTxD.
    #
    # Fix(review): the extracted code used `y` before assignment and
    # returned nothing; initialize from Y_init and return the stacked
    # sequence so the output matches pscan's.
    y = Y_init
    Y = []
    for k in range(A.size(1)):
        y = A[:, k, None] * y + X[:, k]
        Y.append(y)
    return torch.stack(Y, dim=1)
120 ######################################################################
if __name__ == "__main__":
    import time

    # NOTE(review): this script section lost several lines during
    # extraction (N, D, the `pscan` alias and the loss `s` were
    # referenced but never defined); reconstructed into a coherent
    # self-check that preserves every surviving statement.

    pscan = PScan.apply

    # Also works with an extra "model" dimension, e.g.
    # A = torch.rand(17, 12, 3)
    # X = torch.rand(17, 12, 3, 11)
    # Y_init = torch.rand(17, 3, 11)
    # Y = pscan(A, X, Y_init)

    N, D = 2, 3
    T = torch.randint(10, (1,)).item() + 1

    # float64 and A close to 1 keep the numerical comparison tight
    A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
    Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()

    # Iterative implementation (reference)

    y = Y_init
    s = 0

    for k in range(A.size(1)):
        y = A[:, k, None] * y + X[:, k]
        s = s + y

    s = s.sum()

    gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
        s, (A, X, Y_init), retain_graph=True
    )

    # Parallel-scan implementation, timed over 1000 runs

    start_time = time.perf_counter()
    for _ in range(1000):
        Y = pscan(A, X, Y_init)
    duration = time.perf_counter() - start_time
    print(f"duration {duration}")

    s = Y.sum()
    gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)

    # Maximum absolute gradient error vs the iterative reference
    print(
        "gradient errors",
        (gA - gA_ref).abs().max(),
        (gX - gX_ref).abs().max(),
        (gY_init - gY_init_ref).abs().max(),
    )

    # Chaining check: scanning two halves sequentially should match
    # Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
    # Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
    # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max())