def pscan_diff(X, V, s=1):
    """Out-of-place recursive pairwise scan of ``V`` with its payload ``X``.

    Along dimension 1, a position is overwritten by its predecessor whenever
    ``V[prev] - s >= V[cur]``: it then takes the value ``V[prev] - s`` and the
    payload row ``X[prev]``.  The pairs are resolved recursively with the
    step cost doubled at each level, so the result at position ``t`` is
    effectively the best ``V[t'] - s * (t - t')`` over ``t' <= t`` (ties go
    to the earlier position) together with the matching ``X[t']``.

    Args:
        X: payload tensor; the indexing below requires three dimensions,
            with dimension 1 being the scan dimension.
        V: value tensor; the indexing below requires two dimensions matching
            the first two of ``X``.
            # assumes an integer dtype compatible with the ``.long()`` masks
            # -- TODO confirm against callers
        s: cost subtracted per step along dimension 1 at this recursion
            level.

    Returns:
        ``(Xr, Vr)``: new tensors with the shapes of ``X`` and ``V``.  The
        inputs are not written to.  When the scan dimension has length 0 or
        1 the inputs themselves are returned (no copy).
    """
    n = X.size(1)
    # Nothing to combine. ``<= 1`` also stops the recursion for an empty
    # scan dimension, which would otherwise never reach a length of 1.
    if n <= 1:
        return X, V

    # Largest even prefix, folded into (pair index, position-in-pair).
    T = 2 * (n // 2)
    Xf = X[:, :T].view(X.size(0), n // 2, 2, X.size(2))
    Vf = V[:, :T].view(V.size(0), n // 2, 2)

    # Result buffers; every entry is written below before being returned.
    Xr = X.new_empty(X.size())
    Vr = V.new_empty(V.size())
    Xrf = Xr[:, :T].view(Xr.size(0), n // 2, 2, Xr.size(2))
    Vrf = Vr[:, :T].view(Vr.size(0), n // 2, 2)

    # [:, :, 0] < [:, :, 1]
    # Combine each pair locally: element 1 may be overridden by element 0.
    m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
    Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
    m = m[:, :, None]
    Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]

    # Scan the pair summaries; they sit two positions apart, hence 2 * s.
    Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)

    # The first position has no predecessor.
    Xr[:, 0] = X[:, 0]
    Vr[:, 0] = V[:, 0]

    # [:, :-1, 1] < [:, 1:, 0]
    # Element 0 of each later pair is compared with the already scanned
    # element 1 of the pair before it.
    m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
    Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
    m = m[:, :, None]
    Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]

    if T < n:
        # [:, -2] < [:, -1]
        # Odd length: the last position is compared with the *scanned*
        # predecessor held in Vr / Xr. The inputs V / X are never updated
        # by this out-of-place version, so reading them here would ignore
        # everything propagated so far.
        m = (Vr[:, -2] - s >= V[:, -1]).long()
        Vr[:, -1] = m * (Vr[:, -2] - s) + (1 - m) * V[:, -1]
        m = m[:, None]
        Xr[:, -1] = m * Xr[:, -2] + (1 - m) * X[:, -1]

    return Xr, Vr
######################################################################
# print(V0)
# print(X0)
- X1, V1 = X.clone(), V.clone()
- pscan_diff(X1, V1)
+ X1, V1 = pscan_diff(X, V)
# print("########### X V ############################################")
# print(V)