Update.
[pytorch.git] / pscan.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 def naive_rec(A, X, Y0):
14     Y = []
15     for t in range(X.size(1)):
16         if t == 0:
17             Y.append(A[:, t] * Y0 + X[:, t])
18         else:
19             Y.append(A[:, t] * Y[-1] + X[:, t])
20
21     return torch.cat([y[:, None, :] for y in Y], dim=1)
22
23
24 ######################################################################
25
26 # A is NxTx1 and X is NxTxD
27 #
28 # Returns Y defined with
29 #
30 #           Y[:, 0] = A[:, 0] * Y0 + X[:,0]
31 # for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
32
33
34 def pscan_rec(A, X, Y0):
35     if X.size(1) % 2 == 1:
36         if X.size(1) == 1:
37             return A[:, :1] * Y0[:, None] + X[:, :1]
38         else:
39             Y = pscan_rec(A[:, :-1], X[:, :-1], Y0)
40             return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1)
41
42     A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2))
43     X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2))
44
45     X_star = X2[:, :, 0].clone()
46     X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1]
47
48     A_star = A2[:, :, 0].clone()
49     A_star[:, 1:] *= A2[:, :-1, 1]
50
51     Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None]
52
53     Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2)
54
55     Y = Y.reshape(Y.size(0), -1, Y.size(-1))
56
57     return Y
58
59
60 ######################################################################
61
62 N, T, D = 5, 29, 12
63
64 A = torch.rand(N, T, 1, dtype=torch.float64)
65 X = torch.randint(10, (N, T, D), dtype=torch.float64)
66 Y0 = torch.randint(10, (N, D), dtype=torch.float64)
67
68 naive_Y = naive_rec(A, X, Y0)
69
70 pscan_Y = pscan_rec(A, X, Y0)
71
72 print((naive_Y - pscan_Y).pow(2).mean())
73
74 pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0)
75 pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1])
76
77 print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean())