From 703baba8b6619c57562e1d622be75c8b409659e5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 7 Jul 2023 23:31:06 +0200 Subject: [PATCH] Update. --- maze.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/maze.py b/maze.py index f6715f0..8ac9fce 100755 --- a/maze.py +++ b/maze.py @@ -221,8 +221,6 @@ def save_image( mazes, target_paths=None, predicted_paths=None, - score_paths=None, - score_truth=None, path_correct=None, path_optimal=None, ): @@ -242,17 +240,6 @@ def save_image( colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) ) - if score_truth is not None: - score_truth = score_truth.cpu() - c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1) - c_score_truth = ( - c_score_truth * colors[4].reshape(1, 3, 1, 1) - + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1) - ).long() - c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + ( - mazes.unsqueeze(1) == v_empty - ) * c_score_truth - imgs = c_mazes.unsqueeze(1) if target_paths is not None: @@ -275,18 +262,6 @@ def save_image( ) imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1) - if score_paths is not None: - score_paths = score_paths.cpu() - c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1) - c_score_paths = ( - c_score_paths * colors[4].reshape(1, 3, 1, 1) - + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1) - ).long() - c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * ( - mazes.unsqueeze(1) != v_empty - ) - imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1) - img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1) # NxKxCxHxW @@ -307,6 +282,8 @@ def save_image( -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) ).clone() + print(f"{img.size()=} {imgs.size()=}") + for k in range(imgs.size(1)): img[ :, -- 2.20.1