Update
authorFrançois Fleuret <francois@fleuret.org>
Fri, 17 Mar 2023 16:40:57 +0000 (17:40 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 17 Mar 2023 16:40:57 +0000 (17:40 +0100)
maze.py

diff --git a/maze.py b/maze.py
index 6e8e179..754cdea 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -207,34 +207,43 @@ def save_image(
 
     mazes = mazes.cpu()
 
-    mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    c_mazes = (
+        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    )
 
-    imgs = mazes.unsqueeze(1)
+    imgs = c_mazes.unsqueeze(1)
 
     if target_paths is not None:
         target_paths = target_paths.cpu()
 
-        target_paths = (
+        c_target_paths = (
             colors[target_paths.reshape(-1)]
             .reshape(target_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
 
-        imgs = torch.cat((imgs, target_paths.unsqueeze(1)), 1)
+        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
-        predicted_paths = (
+        c_predicted_paths = (
             colors[predicted_paths.reshape(-1)]
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
-        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
 
     if score_paths is not None:
-        score_paths = (score_paths.cpu() * 255.0).long()
-        score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
-        imgs = torch.cat((imgs, score_paths.unsqueeze(1)), 1)
+        score_paths = score_paths.cpu()
+        c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_paths = (
+            c_score_paths * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_paths) * colors[3].reshape(1, 3, 1, 1)
+        ).long()
+        c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+            mazes.unsqueeze(1) != v_empty
+        )
+        imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
 
     # NxKxCxHxW
     if path_correct is None: