Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index d11ab6e..754cdea 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -158,11 +158,11 @@ def create_maze_data(
 ):
     mazes = torch.empty(nb, height, width, dtype=torch.int64)
     paths = torch.empty(nb, height, width, dtype=torch.int64)
-    policies = torch.empty(nb, 4, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width)
 
     for n in progress_bar(range(nb)):
         maze = create_maze(height, width, nb_walls)
-        i = (1 - maze).nonzero()
+        i = (maze == v_empty).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
             if (start - goal).abs().sum() >= dist_min:
@@ -187,9 +187,14 @@ def create_maze_data(
 ######################################################################
 
 
-def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
-    mazes, target_paths = mazes.cpu(), target_paths.cpu()
-
+def save_image(
+    name,
+    mazes,
+    target_paths=None,
+    predicted_paths=None,
+    score_paths=None,
+    path_correct=None,
+):
     colors = torch.tensor(
         [
             [255, 255, 255],  # empty
@@ -200,22 +205,45 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
         ]
     )
 
-    mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    target_paths = (
-        colors[target_paths.reshape(-1)]
-        .reshape(target_paths.size() + (-1,))
-        .permute(0, 3, 1, 2)
+    mazes = mazes.cpu()
+
+    c_mazes = (
+        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
     )
-    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+    imgs = c_mazes.unsqueeze(1)
+
+    if target_paths is not None:
+        target_paths = target_paths.cpu()
+
+        c_target_paths = (
+            colors[target_paths.reshape(-1)]
+            .reshape(target_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+
+        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
-        predicted_paths = (
+        c_predicted_paths = (
             colors[predicted_paths.reshape(-1)]
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
-        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+    if score_paths is not None:
+        score_paths = score_paths.cpu()
+        c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_paths = (
+            c_score_paths * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_paths) * colors[3].reshape(1, 3, 1, 1)
+        ).long()
+        c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+            mazes.unsqueeze(1) != v_empty
+        )
+        imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
 
     # NxKxCxHxW
     if path_correct is None: