- if A.size(1) == 1:
- return
- T = 2 * (A.size(1) // 2)
- Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
- Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
- Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
- Aa[:, :, 1].mul_(Aa[:, :, 0])
- PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
- Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
- Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
- if T < A.size(1):
- X[:, -1].add_(A[:, -1].mul(X[:, -2]))
- A[:, -1].mul_(A[:, -2])
+ # Unrolling the short cases (A.size(1) of 2, 3 or 4) by hand instead of recursing gains ~8% speed
+
+ if A.size(1) > 4:
+ T = 2 * (A.size(1) // 2)
+ Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
+ Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
+ PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
+ Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+ Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+ if T < A.size(1):
+ X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+ A[:, -1].mul_(A[:, -2])
+ elif A.size(1) == 2:
+ X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+ A[:, 1].mul_(A[:, 0])
+ elif A.size(1) == 3:
+ X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+ A[:, 1].mul_(A[:, 0])
+ X[:, 2].add_(A[:, 2].mul(X[:, 1]))
+ A[:, 2].mul_(A[:, 1])
+ elif A.size(1) == 4:
+ X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+ A[:, 1].mul_(A[:, 0])
+ X[:, 2].add_(A[:, 2].mul(X[:, 1]))
+ A[:, 2].mul_(A[:, 1])
+ X[:, 3].add_(A[:, 3].mul(X[:, 2]))
+ A[:, 3].mul_(A[:, 2])