From c7475c52af263fbfe367797682979525532e4d7e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 May 2023 07:21:20 +0200 Subject: [PATCH] Update --- rmax.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/rmax.py b/rmax.py index 291ce92..b51a5bf 100755 --- a/rmax.py +++ b/rmax.py @@ -31,10 +31,16 @@ def rmax_back(y): ################################################## -x = torch.randn(3, 14) +# x = torch.randn(3, 14) + +# Checking that ambiguous max does not hurt +x = torch.randint(21, (100, 50)) - 10 + y = rmax(x) -print(f"{x.size()=} {x.max(-1).values=}") -print(f"{y.size()=} {y[:,-1]=}") + +print(f"{x.size()=} {y.size()=}") +print(f"{(x.max(-1).values - y[:,-1]).abs().max()=}") z = rmax_back(y) + print(f"{(z-x).abs().max()=}") -- 2.39.5