Update.
[pytorch.git] / pscan.py
1 #!/usr/bin/env python
2
3 import torch
4
5 ######################################################################
6
7
8 class PScan(torch.autograd.Function):
9     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
10     # and O(log(T)) if not core-bounded, so that
11     #
12     # Y[:, 0] = Y0
13     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
14     #
15     # can be computed as
16     #
17     # Y[:, t] = A[:, t] * Y0 + X[:, t]
18
19     @staticmethod
20     def expand(A, X):
21         if A.size(1) == 1:
22             return
23         T = 2 * (A.size(1) // 2)
24         Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
25         Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
26         Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
27         Aa[:, :, 1].mul_(Aa[:, :, 0])
28         PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
29         Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
30         Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
31         if T < A.size(1):
32             X[:, -1].add_(A[:, -1].mul(X[:, -2]))
33             A[:, -1].mul_(A[:, -2])
34
35     # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
36
37     @staticmethod
38     def accrev(X):
39         if X.size(1) == 1:
40             return
41         T = 2 * (X.size(1) // 2)
42         Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
43         Xa[:, :, 0].add_(Xa[:, :, 1])
44         PScan.accrev(Xa[:, :, 0])
45         Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
46         if T < X.size(1):
47             X[:, 0].add_(X[:, 1])
48
49     @staticmethod
50     def forward(ctx, A, X, Y0):
51         ctx.A = A[:, :, None].clone()
52         ctx.Y0 = Y0[:, None, :].clone()
53         ctx.A_star = A[:, :, None].clone()
54         ctx.X_star = X.clone()
55         PScan.expand(ctx.A_star, ctx.X_star)
56         return ctx.A_star * ctx.Y0 + ctx.X_star
57
58     @staticmethod
59     def backward(ctx, grad_output):
60         U = grad_output * ctx.A_star
61         R = U.clone()
62         PScan.accrev(R)
63         Q = ctx.Y0 / ctx.A
64         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
65         return (Q * R).sum(-1), R / ctx.A_star, U
66
67
68 pscan = PScan.apply
69
70 ######################################################################
71
72 if __name__ == "__main__":
73     # Iterative implementation
74
75     A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
76     X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
77     Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
78
79     y = Y0[:, None]
80
81     for k in range(A.size(1)):
82         y = A[:, k, None] * y + X[:, k]
83         print(f"{k} -> {y}")
84
85     print(torch.autograd.grad(y.mean(), A, retain_graph=True))
86     print(torch.autograd.grad(y.mean(), X, retain_graph=True))
87     print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
88
89     print()
90
91     # parallel scan
92
93     Y = pscan(A, X, Y0)
94
95     for k in range(A.size(1)):
96         print(f"{k} -> {Y[:,k]}")
97
98     y = Y[:, -1]
99
100     print(torch.autograd.grad(y.mean(), A, retain_graph=True))
101     print(torch.autograd.grad(y.mean(), X, retain_graph=True))
102     print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))