Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:45:49 +0000 (18:45 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:45:49 +0000 (18:45 +0100)
maxval.py

index a202b17..31aa4b6 100755 (executable)
--- a/maxval.py
+++ b/maxval.py
@@ -59,33 +59,43 @@ def pscan(X, V, s=1):
 
 def pscan_diff(X, V, s=1):
     if X.size(1) == 1:
-        return
+        return X, V
 
     T = 2 * (X.size(1) // 2)
 
     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
 
+    Xr = X.new(X.size())
+    Vr = V.new(V.size())
+    Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2))
+    Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
+
     # [:, :, 0] < [:, :, 1]
     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
-    Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
+    Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
     m = m[:, :, None]
-    Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+    Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+
+    Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
 
-    pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
+    Xr[:, 0] = X[:, 0]
+    Vr[:, 0] = V[:, 0]
 
     # [:, :-1, 1] < [:, 1:, 0]
-    m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
-    Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
+    m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+    Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
     m = m[:, :, None]
-    Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+    Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
 
     if T < X.size(1):
         # [:, -2] < [:, -1]
         m = (V[:, -2] - s >= V[:, -1]).long()
-        V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+        Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
         m = m[:, None]
-        X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+        Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+
+    return Xr, Vr
 
 
 ######################################################################
@@ -104,8 +114,7 @@ if __name__ == "__main__":
     # print(V0)
     # print(X0)
 
-    X1, V1 = X.clone(), V.clone()
-    pscan_diff(X1, V1)
+    X1, V1 = pscan_diff(X, V)
 
     # print("########### X V ############################################")
     # print(V)