Update.
[pytorch.git] / pscan.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9
10 ######################################################################
11
12
13 class PScan(torch.autograd.Function):
14     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
15     # and O(log(T)) if not core-bounded, so that
16     #
17     # Y[:, 0] = Y0
18     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
19     #
20     # can be computed as
21     #
22     # Y[:, t] = A[:, t] * Y0 + X[:, t]
23
24     @staticmethod
25     def expand(A, X):
26         if A.size(1) == 1:
27             return
28         T = 2 * (A.size(1) // 2)
29         Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
30         Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
31         Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
32         Aa[:, :, 1].mul_(Aa[:, :, 0])
33         PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
34         Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
35         Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
36         if T < A.size(1):
37             X[:, -1].add_(A[:, -1].mul(X[:, -2]))
38             A[:, -1].mul_(A[:, -2])
39
40     # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
41
42     @staticmethod
43     def accrev(X):
44         if X.size(1) == 1:
45             return
46         T = 2 * (X.size(1) // 2)
47         Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
48         Xa[:, :, 0].add_(Xa[:, :, 1])
49         PScan.accrev(Xa[:, :, 0])
50         Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
51         if T < X.size(1):
52             X[:, 0].add_(X[:, 1])
53
54     @staticmethod
55     def forward(ctx, A, X, Y0):
56         ctx.A = A[:, :, None].clone()
57         ctx.Y0 = Y0[:, None, :].clone()
58         ctx.A_star = A[:, :, None].clone()
59         ctx.X_star = X.clone()
60         PScan.expand(ctx.A_star, ctx.X_star)
61         return ctx.A_star * ctx.Y0 + ctx.X_star
62
63     @staticmethod
64     def backward(ctx, grad_output):
65         U = grad_output * ctx.A_star
66         R = U.clone()
67         PScan.accrev(R)
68         Q = ctx.Y0 / ctx.A
69         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
70         return (Q * R).sum(-1), R / ctx.A_star, U
71
72
73 pscan = PScan.apply
74
75 ######################################################################
76
77 if __name__ == "__main__":
78     # Iterative implementation
79
80     A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
81     X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
82     Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
83
84     y = Y0[:, None]
85
86     for k in range(A.size(1)):
87         y = A[:, k, None] * y + X[:, k]
88         print(f"{k} -> {y}")
89
90     print(torch.autograd.grad(y.mean(), A, retain_graph=True))
91     print(torch.autograd.grad(y.mean(), X, retain_graph=True))
92     print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
93
94     print()
95
96     # parallel scan
97
98     Y = pscan(A, X, Y0)
99
100     for k in range(A.size(1)):
101         print(f"{k} -> {Y[:,k]}")
102
103     y = Y[:, -1]
104
105     print(torch.autograd.grad(y.mean(), A, retain_graph=True))
106     print(torch.autograd.grad(y.mean(), X, retain_graph=True))
107     print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))