X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=maxval.py;h=31aa4b6698b7d6f27721f5349c3610eb7a2d59a1;hb=afd42d993de96b1a5dd31ff57f63a48fdfc00243;hp=747f4b9c336a83e8e39f54b1937e3f14e27c8598;hpb=6ca2c05c7470e92cd591fe1b8de33c80c1b27180;p=mygptrnn.git diff --git a/maxval.py b/maxval.py index 747f4b9..31aa4b6 100755 --- a/maxval.py +++ b/maxval.py @@ -25,13 +25,14 @@ def baseline(X, V): def pscan(X, V, s=1): if X.size(1) == 1: - return X, V + return T = 2 * (X.size(1) // 2) Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2)) Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2) + # [:, :, 0] < [:, :, 1] m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] m = m[:, :, None] @@ -39,18 +40,64 @@ def pscan(X, V, s=1): pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2) - m = (Vf[:, 1:, 0] >= Vf[:, :-1, 1] - s).long() - Vf[:, 1:, 0] = m * Vf[:, 1:, 0] + (1 - m) * (Vf[:, :-1, 1] - s) + # [:, :-1, 1] < [:, 1:, 0] + m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() + Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] m = m[:, :, None] - Xf[:, 1:, 0] = m * Xf[:, 1:, 0] + (1 - m) * Xf[:, :-1, 1] + Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] if T < X.size(1): + # [:, -2] < [:, -1] m = (V[:, -2] - s >= V[:, -1]).long() V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] m = m[:, None] X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] +###################################################################### + + +def pscan_diff(X, V, s=1): + if X.size(1) == 1: + return X, V + + T = 2 * (X.size(1) // 2) + + Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2)) + Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2) + + Xr = X.new(X.size()) + Vr = V.new(V.size()) + Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2)) + Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2) + + # [:, :, 0] < [:, :, 1] + m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() + Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] + m = m[:, :, None] + Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1] + + Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2) + + Xr[:, 0] = X[:, 0] + Vr[:, 0] = V[:, 0] + + # [:, :-1, 1] < [:, 1:, 0] + m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() + Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] + m = m[:, :, None] + Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] + + if T < X.size(1): + # [:, -2] < [:, -1] + m = (V[:, -2] - s >= V[:, -1]).long() + Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] + m = m[:, None] + Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] + + return Xr, Vr + + ###################################################################### if __name__ == "__main__": @@ -58,23 +105,24 @@ if __name__ == "__main__": T = 513 D = 2 - # X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() - # V = torch.rand(N, T, dtype=torch.float64) * 50 + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + V = torch.rand(N, T, dtype=torch.float64) * 10 - # X0, V0 = baseline(X, V) + X0, V0 = baseline(X, V) # print("########### X0 V0 ###########################################") # print(V0) # print(X0) - # X1, V1 = X.clone(), V.clone() - # pscan(X1, V1) + X1, V1 = pscan_diff(X, V) # print("########### X V ############################################") # print(V) # print(X) - # print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()) + print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()) + + exit(0) # s = X1.sum() # print(torch.autograd.grad(s, X)) @@ -93,7 +141,7 @@ if __name__ == "__main__": for k in range(1000): X1, V1 = X.clone(), V.clone() - pscan(X1, V1) + pscan(X, V, X1, V1) # X1=X1*(1+V1-V1.detach())[:,:,None] loss = (X1[:, -1:] - Y).pow(2).mean() print(k, loss.item())