From: François Fleuret Date: Sat, 27 Jan 2024 17:50:30 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6d23462ce76c9020dcd7c4bc8a0e7a0fae9b7971;p=mygptrnn.git Update. --- diff --git a/maxval.py b/maxval.py index 31aa4b6..a245287 100755 --- a/maxval.py +++ b/maxval.py @@ -72,10 +72,12 @@ def pscan_diff(X, V, s=1): Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2) # [:, :, 0] < [:, :, 1] + dx = Xf[:, :, 1] - Xf[:, :, 1].detach() + dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None] m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] m = m[:, :, None] - Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1] + Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx) Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2) @@ -83,17 +85,21 @@ def pscan_diff(X, V, s=1): Vr[:, 0] = V[:, 0] # [:, :-1, 1] < [:, 1:, 0] + dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach() + dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None] m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] m = m[:, :, None] - Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] + Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx) if T < X.size(1): # [:, -2] < [:, -1] + dx = X[:, -1] - X[:, -1].detach() + dv = (V[:, -1] - V[:, -1].detach())[:, None] m = (V[:, -2] - s >= V[:, -1]).long() Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] m = m[:, None] - Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] + Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx) return Xr, Vr