Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 6c3fe94..6e8e179 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -187,9 +187,14 @@ def create_maze_data(
 ######################################################################
 
 
-def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
-    mazes, target_paths = mazes.cpu(), target_paths.cpu()
-
+def save_image(
+    name,
+    mazes,
+    target_paths=None,
+    predicted_paths=None,
+    score_paths=None,
+    path_correct=None,
+):
     colors = torch.tensor(
         [
             [255, 255, 255],  # empty
@@ -200,13 +205,22 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
         ]
     )
 
+    mazes = mazes.cpu()
+
     mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    target_paths = (
-        colors[target_paths.reshape(-1)]
-        .reshape(target_paths.size() + (-1,))
-        .permute(0, 3, 1, 2)
-    )
-    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+    imgs = mazes.unsqueeze(1)
+
+    if target_paths is not None:
+        target_paths = target_paths.cpu()
+
+        target_paths = (
+            colors[target_paths.reshape(-1)]
+            .reshape(target_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+
+        imgs = torch.cat((imgs, target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
@@ -217,6 +231,11 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
         )
         imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
 
+    if score_paths is not None:
+        score_paths = (score_paths.cpu() * 255.0).long()
+        score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        imgs = torch.cat((imgs, score_paths.unsqueeze(1)), 1)
+
     # NxKxCxHxW
     if path_correct is None:
         path_correct = torch.zeros(imgs.size(0)) <= 1