5 ######################################################################
11 for t in range(X.size(1)):
16 m = (V[:, t] >= W[:, t - 1] - 1).long()
17 Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
18 W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
23 ######################################################################
30 T = 2 * (X.size(1) // 2)
32 Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
33 Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
35 # [:, :, 0] < [:, :, 1]
36 m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
37 Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
39 Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
41 pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
43 # [:, :-1, 1] < [:, 1:, 0]
44 m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
45 Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
47 Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
51 m = (V[:, -2] - s >= V[:, -1]).long()
52 V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
54 X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
57 ######################################################################
60 def pscan_diff(X, V, s=1):
64 T = 2 * (X.size(1) // 2)
66 Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
67 Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
69 # [:, :, 0] < [:, :, 1]
70 m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
71 Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
73 Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
75 pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
77 # [:, :-1, 1] < [:, 1:, 0]
78 m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
79 Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
81 Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
85 m = (V[:, -2] - s >= V[:, -1]).long()
86 V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
88 X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
91 ######################################################################
93 if __name__ == "__main__":
98 X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
99 V = torch.rand(N, T, dtype=torch.float64) * 10
101 X0, V0 = baseline(X, V)
103 # print("########### X0 V0 ###########################################")
107 X1, V1 = X.clone(), V.clone()
110 # print("########### X V ############################################")
114 print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
119 # print(torch.autograd.grad(s, X))
121 # with open("/tmp/v.dat", "w") as f:
123 # f.write(f"{V1[0,t].item()}\n")
125 Y = torch.randn(1, 1, D)
128 ) # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y
129 V = torch.rand(N, T).requires_grad_()
131 optimizer = torch.optim.SGD([V], lr=1e-2)
133 for k in range(1000):
134 X1, V1 = X.clone(), V.clone()
136 # X1=X1*(1+V1-V1.detach())[:,:,None]
137 loss = (X1[:, -1:] - Y).pow(2).mean()
138 print(k, loss.item())
139 optimizer.zero_grad()