Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:07:04 +0000 (18:07 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:07:04 +0000 (18:07 +0100)
maxval.py

index 747f4b9..a202b17 100755 (executable)
--- a/maxval.py
+++ b/maxval.py
@@ -25,13 +25,14 @@ def baseline(X, V):
 
 def pscan(X, V, s=1):
     if X.size(1) == 1:
-        return X, V
+        return
 
     T = 2 * (X.size(1) // 2)
 
     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
 
+    # [:, :, 0] < [:, :, 1]
     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
     Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
     m = m[:, :, None]
@@ -39,12 +40,48 @@ def pscan(X, V, s=1):
 
     pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
 
-    m = (Vf[:, 1:, 0] >= Vf[:, :-1, 1] - s).long()
-    Vf[:, 1:, 0] = m * Vf[:, 1:, 0] + (1 - m) * (Vf[:, :-1, 1] - s)
+    # [:, :-1, 1] < [:, 1:, 0]
+    m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+    Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
     m = m[:, :, None]
-    Xf[:, 1:, 0] = m * Xf[:, 1:, 0] + (1 - m) * Xf[:, :-1, 1]
+    Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
 
     if T < X.size(1):
+        # [:, -2] < [:, -1]
+        m = (V[:, -2] - s >= V[:, -1]).long()
+        V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+        m = m[:, None]
+        X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+
+
+######################################################################
+
+
+def pscan_diff(X, V, s=1):
+    if X.size(1) == 1:
+        return
+
+    T = 2 * (X.size(1) // 2)
+
+    Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
+    Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
+
+    # [:, :, 0] < [:, :, 1]
+    m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
+    Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
+    m = m[:, :, None]
+    Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+
+    pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
+
+    # [:, :-1, 1] < [:, 1:, 0]
+    m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+    Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
+    m = m[:, :, None]
+    Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+
+    if T < X.size(1):
+        # [:, -2] < [:, -1]
         m = (V[:, -2] - s >= V[:, -1]).long()
         V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
         m = m[:, None]
@@ -58,23 +95,25 @@ if __name__ == "__main__":
     T = 513
     D = 2
 
-    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-    # V = torch.rand(N, T, dtype=torch.float64) * 50
+    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    V = torch.rand(N, T, dtype=torch.float64) * 10
 
-    X0, V0 = baseline(X, V)
+    X0, V0 = baseline(X, V)
 
     # print("########### X0 V0 ###########################################")
     # print(V0)
     # print(X0)
 
-    X1, V1 = X.clone(), V.clone()
-    # pscan(X1, V1)
+    X1, V1 = X.clone(), V.clone()
+    pscan_diff(X1, V1)
 
     # print("########### X V ############################################")
     # print(V)
     # print(X)
 
-    # print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
+    print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
+
+    exit(0)
 
     # s = X1.sum()
     # print(torch.autograd.grad(s, X))
@@ -93,7 +132,7 @@ if __name__ == "__main__":
 
     for k in range(1000):
         X1, V1 = X.clone(), V.clone()
-        pscan(X1, V1)
+        pscan(X, V, X1, V1)
         # X1=X1*(1+V1-V1.detach())[:,:,None]
         loss = (X1[:, -1:] - Y).pow(2).mean()
         print(k, loss.item())