Update.
[pytorch.git] / pscan.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9
10 ######################################################################
11
12
13 class PScan(torch.autograd.Function):
14     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
15     # and O(log(T)) if not core-bounded, so that
16     #
17     # Y[:, 0] = Y0
18     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
19     #
20     # can be computed as
21     #
22     # Y[:, t] = A[:, t] * Y0 + X[:, t]
23
24     @staticmethod
25     def expand(A, X):
26         if A.size(1) == 1:
27             return
28         T = 2 * (A.size(1) // 2)
29         Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
30         Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
31         Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
32         Aa[:, :, 1].mul_(Aa[:, :, 0])
33         PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
34         Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
35         Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
36         if T < A.size(1):
37             X[:, -1].add_(A[:, -1].mul(X[:, -2]))
38             A[:, -1].mul_(A[:, -2])
39
40     # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
41
42     @staticmethod
43     def accrev(X):
44         if X.size(1) == 1:
45             return
46         T = 2 * (X.size(1) // 2)
47         Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
48         Xa[:, :, 0].add_(Xa[:, :, 1])
49         PScan.accrev(Xa[:, :, 0])
50         Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
51         if T < X.size(1):
52             X[:, 0].add_(X[:, 1])
53
54     @staticmethod
55     def forward(ctx, A, X, Y0):
56         ctx.A = A[:, :, None].clone()
57         ctx.Y0 = Y0[:, None, :].clone()
58         ctx.A_star = A[:, :, None].clone()
59         ctx.X_star = X.clone()
60         PScan.expand(ctx.A_star, ctx.X_star)
61         return ctx.A_star * ctx.Y0 + ctx.X_star
62
63     @staticmethod
64     def backward(ctx, grad_output):
65         U = grad_output * ctx.A_star
66         R = U.clone()
67         PScan.accrev(R)
68         Q = ctx.Y0 / ctx.A
69         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
70         return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1)
71
72
73 pscan = PScan.apply
74
75 ######################################################################
76
77 if __name__ == "__main__":
78     N, T, D = 2, 5, 3
79
80     # Iterative implementation
81
82     A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
83     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
84     Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
85
86     y = Y0
87     s = 0
88
89     for k in range(A.size(1)):
90         y = A[:, k, None] * y + X[:, k]
91         s = s + y
92         # print(f"{k} -> {y}")
93
94     s = s.sum()
95
96     # print(s)
97     print(torch.autograd.grad(s, A, retain_graph=True))
98     print(torch.autograd.grad(s, X, retain_graph=True))
99     print(torch.autograd.grad(s, Y0, retain_graph=True))
100
101     print()
102
103     # parallel scan
104
105     Y = pscan(A, X, Y0)
106
107     # for k in range(A.size(1)):
108     # print(f"{k} -> {Y[:,k]}")
109
110     s = Y.sum()
111
112     # print(s)
113     print(torch.autograd.grad(s, A, retain_graph=True))
114     print(torch.autograd.grad(s, X, retain_graph=True))
115     print(torch.autograd.grad(s, Y0, retain_graph=True))