From: François Fleuret Date: Sat, 27 Jan 2024 17:45:49 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=afd42d993de96b1a5dd31ff57f63a48fdfc00243;p=mygptrnn.git Update. --- diff --git a/maxval.py b/maxval.py index a202b17..31aa4b6 100755 --- a/maxval.py +++ b/maxval.py @@ -59,33 +59,43 @@ def pscan(X, V, s=1): def pscan_diff(X, V, s=1): if X.size(1) == 1: - return + return X, V T = 2 * (X.size(1) // 2) Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2)) Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2) + Xr = X.new(X.size()) + Vr = V.new(V.size()) + Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2)) + Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2) + # [:, :, 0] < [:, :, 1] m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() - Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] + Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] m = m[:, :, None] - Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1] + Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1] + + Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2) - pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2) + Xr[:, 0] = X[:, 0] + Vr[:, 0] = V[:, 0] # [:, :-1, 1] < [:, 1:, 0] - m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() - Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] + m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() + Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] m = m[:, :, None] - Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] + Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] if T < X.size(1): # [:, -2] < [:, -1] m = (V[:, -2] - s >= V[:, -1]).long() - V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] + Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] m = m[:, None] - X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] + Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] + + return Xr, Vr ###################################################################### @@ -104,8 +114,7 @@ if __name__ == "__main__": # print(V0) # print(X0) - X1, V1 = X.clone(), V.clone() - pscan_diff(X1, V1) + X1, V1 = pscan_diff(X, V) # print("########### X V ############################################") # print(V)