Update.
[pytorch.git] / rmax.py
diff --git a/rmax.py b/rmax.py
new file mode 100755 (executable)
index 0000000..291ce92
--- /dev/null
+++ b/rmax.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+
+import torch
+
+##################################################
+
+
+def rmax(x):
+    a = x.max(-1, keepdim=True)
+    i = torch.arange(x.size(-1) - 1)[None, :]
+    y = torch.cat(
+        (
+            (i < a.indices) * (x - a.values)[:, :-1]
+            + (i >= a.indices) * (a.values - x)[:, 1:],
+            a.values,
+        ),
+        -1,
+    )
+    return y
+
+
+def rmax_back(y):
+    u = torch.nn.functional.pad(y, (1, -1))
+    x = (
+        (y < 0) * (y[:, -1:] + y)
+        + (y >= 0) * (u < 0) * (y[:, -1:])
+        + (y >= 0) * (u >= 0) * (y[:, -1:] - u)
+    )
+    return x
+
+
+##################################################
+
+x = torch.randn(3, 14)
+y = rmax(x)
+print(f"{x.size()=} {x.max(-1).values=}")
+print(f"{y.size()=} {y[:,-1]=}")
+
+z = rmax_back(y)
+print(f"{(z-x).abs().max()=}")