5 ##################################################
# Fragment: re-expresses each row of x relative to its maximum along the last dimension.
# NOTE(review): the enclosing definition and whatever wraps the final expression are
# not visible here; the leading numbers on each line look like line numbers left over
# from an extraction -- confirm against the original file.
# a.values / a.indices: maximum over the last dimension and its position
# (keepdim=True keeps a trailing size-1 dimension so both broadcast against x).
9 a = x.max(-1, keepdim=True)
# Column positions 0 .. D-2 (D = x.size(-1)), with a leading size-1 dimension so they
# broadcast against a.indices.
10 i = torch.arange(x.size(-1) - 1)[None, :]
# Masked sum over D-1 columns; exactly one of the two terms is active per position:
#   - position before the maximum's index: (x - max), read from columns 0 .. D-2;
#   - position at or after the maximum's index: (max - x), read from columns 1 .. D-1,
#     i.e. the entry one column to the right.
# The `[:, ...]` slicing addresses the second dimension, so this lines up with
# x.max(-1) when x is 2-D -- presumably the intended case, TODO confirm.
# The trailing comma suggests this expression is one element of a tuple/argument list
# whose surrounding lines are missing -- TODO confirm.
13 (i < a.indices) * (x - a.values)[:, :-1]
14 + (i >= a.indices) * (a.values - x)[:, 1:],
# Fragment: rebuilds values from y by combining each entry with y's last column.
# NOTE(review): the enclosing definition, the assignment target of the expression
# below and its closing lines are not visible here; the leading numbers on each line
# look like leftover line numbers from an extraction -- confirm.
# Pad the last dimension by 1 on the left and -1 on the right (negative padding crops):
# u is y shifted one column to the right, with the default pad value in column 0.
23 u = torch.nn.functional.pad(y, (1, -1))
# Three mutually exclusive cases, selected per position by boolean masks
# (y[:, -1:] is the last column of y, kept as a column so it broadcasts):
#   y < 0             -> last column + y
#   y >= 0 and u < 0  -> last column
#   y >= 0 and u >= 0 -> last column - u   (u is the entry one column to the left)
# Presumably the last column of y holds the row maximum and this undoes the
# max-relative encoding -- TODO confirm against the missing lines.
25 (y < 0) * (y[:, -1:] + y)
26 + (y >= 0) * (u < 0) * (y[:, -1:])
27 + (y >= 0) * (u >= 0) * (y[:, -1:] - u)
32 ##################################################
# Demo / sanity check on a random input.
# NOTE(review): the statements that assign y and z are not visible in this fragment
# (the leading numbers look like leftover line numbers with gaps) -- confirm how they
# are computed before relying on the comments below.
# Random test input with 3 rows and 14 columns.
34 x = torch.randn(3, 14)
# Show the input size and the maximum over the last dimension of x.
36 print(f"{x.size()=} {x.max(-1).values=}")
# Show the size of y and its last column, for comparison with the maxima printed above.
37 print(f"{y.size()=} {y[:,-1]=}")
# Largest absolute element-wise difference between z and x; presumably z is the
# reconstruction of x, so this should be close to zero -- TODO confirm.
40 print(f"{(z-x).abs().max()=}")