def pscan(X, V, s=1):
if X.size(1) == 1:
- return X, V
+ return
T = 2 * (X.size(1) // 2)
Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
+ # [:, :, 0] < [:, :, 1]
m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
m = m[:, :, None]
pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
- m = (Vf[:, 1:, 0] >= Vf[:, :-1, 1] - s).long()
- Vf[:, 1:, 0] = m * Vf[:, 1:, 0] + (1 - m) * (Vf[:, :-1, 1] - s)
+ # [:, :-1, 1] < [:, 1:, 0]
+ m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+ Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
m = m[:, :, None]
- Xf[:, 1:, 0] = m * Xf[:, 1:, 0] + (1 - m) * Xf[:, :-1, 1]
+ Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
if T < X.size(1):
+ # [:, -2] < [:, -1]
+ m = (V[:, -2] - s >= V[:, -1]).long()
+ V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+ m = m[:, None]
+ X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+
+
+######################################################################
+
+
+def pscan_diff(X, V, s=1):
+ if X.size(1) == 1:
+ return
+
+ T = 2 * (X.size(1) // 2)
+
+ Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
+ Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
+
+ # [:, :, 0] < [:, :, 1]
+ m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
+ Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
+ m = m[:, :, None]
+ Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+
+ pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
+
+ # [:, :-1, 1] < [:, 1:, 0]
+ m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+ Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
+ m = m[:, :, None]
+ Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+
+ if T < X.size(1):
+ # [:, -2] < [:, -1]
m = (V[:, -2] - s >= V[:, -1]).long()
V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
m = m[:, None]
T = 513
D = 2
- # X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
- # V = torch.rand(N, T, dtype=torch.float64) * 50
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ V = torch.rand(N, T, dtype=torch.float64) * 10
- # X0, V0 = baseline(X, V)
+ X0, V0 = baseline(X, V)
# print("########### X0 V0 ###########################################")
# print(V0)
# print(X0)
- # X1, V1 = X.clone(), V.clone()
- # pscan(X1, V1)
+ X1, V1 = X.clone(), V.clone()
+ pscan_diff(X1, V1)
# print("########### X V ############################################")
# print(V)
# print(X)
- # print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
+ print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
+
+ exit(0)
# s = X1.sum()
# print(torch.autograd.grad(s, X))
for k in range(1000):
X1, V1 = X.clone(), V.clone()
- pscan(X1, V1)
+ pscan(X, V, X1, V1)
# X1=X1*(1+V1-V1.detach())[:,:,None]
loss = (X1[:, -1:] - Y).pow(2).mean()
print(k, loss.item())