##################################################
-x = torch.randn(3, 14)
+# x = torch.randn(3, 14)
+
+# Checking that ambiguous max does not hurt
+x = torch.randint(21, (100, 50)) - 10
+
y = rmax(x)
-print(f"{x.size()=} {x.max(-1).values=}")
-print(f"{y.size()=} {y[:,-1]=}")
+
+print(f"{x.size()=} {y.size()=}")
+print(f"{(x.max(-1).values - y[:,-1]).abs().max()=}")
z = rmax_back(y)
+
print(f"{(z-x).abs().max()=}")