5 ######################################################################
11 for t in range(X.size(1)):
16 m = (V[:, t] >= W[:, t - 1] - 1).long()
17 Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
18 W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
23 ######################################################################
30 T = 2 * (X.size(1) // 2)
32 Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
33 Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
35 m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
36 Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
38 Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
40 pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
42 m = (Vf[:, 1:, 0] >= Vf[:, :-1, 1] - s).long()
43 Vf[:, 1:, 0] = m * Vf[:, 1:, 0] + (1 - m) * (Vf[:, :-1, 1] - s)
45 Xf[:, 1:, 0] = m * Xf[:, 1:, 0] + (1 - m) * Xf[:, :-1, 1]
48 m = (V[:, -2] - s >= V[:, -1]).long()
49 V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
51 X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
54 ######################################################################
56 if __name__ == "__main__":
61 # X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
62 # V = torch.rand(N, T, dtype=torch.float64) * 50
64 # X0, V0 = baseline(X, V)
66 # print("########### X0 V0 ###########################################")
70 # X1, V1 = X.clone(), V.clone()
73 # print("########### X V ############################################")
77 # print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
80 # print(torch.autograd.grad(s, X))
82 # with open("/tmp/v.dat", "w") as f:
84 # f.write(f"{V1[0,t].item()}\n")
86 Y = torch.randn(1, 1, D)
89 ) # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y
90 V = torch.rand(N, T).requires_grad_()
92 optimizer = torch.optim.SGD([V], lr=1e-2)
95 X1, V1 = X.clone(), V.clone()
97 # X1=X1*(1+V1-V1.detach())[:,:,None]
98 loss = (X1[:, -1:] - Y).pow(2).mean()
100 optimizer.zero_grad()