5 ######################################################################
11 for t in range(X.size(1)):
16 m = (V[:, t] >= W[:, t - 1] - 1).long()
17 Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
18 W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
23 ######################################################################
30 T = 2 * (X.size(1) // 2)
32 Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
33 Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
35 # [:, :, 0] < [:, :, 1]
36 m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
37 Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
39 Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
41 pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
43 # [:, :-1, 1] < [:, 1:, 0]
44 m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
45 Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
47 Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
51 m = (V[:, -2] - s >= V[:, -1]).long()
52 V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
54 X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
57 ######################################################################
60 def pscan_diff(X, V, s=1):
64 T = 2 * (X.size(1) // 2)
66 Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
67 Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
71 Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2))
72 Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
74 # [:, :, 0] < [:, :, 1]
75 dx = Xf[:, :, 1] - Xf[:, :, 1].detach()
76 dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None]
77 m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
78 Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
80 Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx)
82 Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
87 # [:, :-1, 1] < [:, 1:, 0]
88 dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach()
89 dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None]
90 m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
91 Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
93 Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx)
97 dx = X[:, -1] - X[:, -1].detach()
98 dv = (V[:, -1] - V[:, -1].detach())[:, None]
99 m = (V[:, -2] - s >= V[:, -1]).long()
100 Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
102 Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
107 ######################################################################
109 if __name__ == "__main__":
114 X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
115 V = torch.rand(N, T, dtype=torch.float64) * 10
117 X0, V0 = baseline(X, V)
119 # print("########### X0 V0 ###########################################")
123 X1, V1 = pscan_diff(X, V)
125 # print("########### X V ############################################")
129 print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
134 # print(torch.autograd.grad(s, X))
136 # with open("/tmp/v.dat", "w") as f:
138 # f.write(f"{V1[0,t].item()}\n")
140 Y = torch.randn(1, 1, D)
143 ) # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y
144 V = torch.rand(N, T).requires_grad_()
146 optimizer = torch.optim.SGD([V], lr=1e-2)
148 for k in range(1000):
149 X1, V1 = X.clone(), V.clone()
151 # X1=X1*(1+V1-V1.detach())[:,:,None]
152 loss = (X1[:, -1:] - Y).pow(2).mean()
153 print(k, loss.item())
154 optimizer.zero_grad()