From b079055ebf75af45b4d61306630e9094b6787342 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 4 Aug 2024 10:55:09 +0200 Subject: [PATCH] Update. --- main.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 4dca41f..7361ae8 100755 --- a/main.py +++ b/main.py @@ -364,14 +364,13 @@ def sigma_for_grids(input): l = input.size(1) // 4 - 1 sigma = input.new(input.size()) r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4) - r[:, :, 0] = 0 - r[:, :, 1:] = ( + r[:, 0] = 0 * l + r[:, 1] = 1 * l + r[:, 2] = 2 * l + r[:, 3] = 3 * l + r[:, :, 1:] += ( torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices ) + 1 - r[:, 0] += 0 * l - r[:, 1] += 1 * l - r[:, 2] += 2 * l - r[:, 3] += 3 * l return sigma -- 2.39.5