-def naive_rec(A, X, Y0):
- Y = []
- for t in range(X.size(1)):
- if t == 0:
- Y.append(A[:, t] * Y0 + X[:, t])
- else:
- Y.append(A[:, t] * Y[-1] + X[:, t])
-
- return torch.cat([y[:, None, :] for y in Y], dim=1)
-
-
-######################################################################
-
-# A is NxTx1 and X is NxTxD
-#
-# Returns Y defined with
-#
-# Y[:, 0] = A[:, 0] * Y0 + X[:,0]
-# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
-
-
-def pscan_rec(A, X, Y0):
- if X.size(1) % 2 == 1:
+class PScan(torch.autograd.Function):
+ # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
+ # and O(log(T)) if not core-bounded, so that
+ #
+ # Y[:, 0] = Y_init
+ # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
+ #
+ # can be computed as
+ #
+ # Y[:, t] = A[:, t] * Y_init + X[:, t]
+
+ @staticmethod
+ def expand(A, X):
+ if A.size(1) == 1:
+ return
+ T = 2 * (A.size(1) // 2)
+ Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+ Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
+ PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
+ Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+ Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+ if T < A.size(1):
+ X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+ A[:, -1].mul_(A[:, -2])
+
+ # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
+
+ @staticmethod
+ def accrev(X):