Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index d11ab6e..6e8e179 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -158,11 +158,11 @@ def create_maze_data(
 ):
     mazes = torch.empty(nb, height, width, dtype=torch.int64)
     paths = torch.empty(nb, height, width, dtype=torch.int64)
-    policies = torch.empty(nb, 4, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width)
 
     for n in progress_bar(range(nb)):
         maze = create_maze(height, width, nb_walls)
-        i = (1 - maze).nonzero()
+        i = (maze == v_empty).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
             if (start - goal).abs().sum() >= dist_min:
@@ -187,9 +187,14 @@ def create_maze_data(
 ######################################################################
 
 
-def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
-    mazes, target_paths = mazes.cpu(), target_paths.cpu()
-
+def save_image(
+    name,
+    mazes,
+    target_paths=None,
+    predicted_paths=None,
+    score_paths=None,
+    path_correct=None,
+):
     colors = torch.tensor(
         [
             [255, 255, 255],  # empty
@@ -200,13 +205,22 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
         ]
     )
 
+    mazes = mazes.cpu()
+
     mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    target_paths = (
-        colors[target_paths.reshape(-1)]
-        .reshape(target_paths.size() + (-1,))
-        .permute(0, 3, 1, 2)
-    )
-    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+    imgs = mazes.unsqueeze(1)
+
+    if target_paths is not None:
+        target_paths = target_paths.cpu()
+
+        target_paths = (
+            colors[target_paths.reshape(-1)]
+            .reshape(target_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+
+        imgs = torch.cat((imgs, target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
@@ -217,6 +231,11 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
         )
         imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
 
+    if score_paths is not None:
+        score_paths = (score_paths.cpu() * 255.0).long()
+        score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        imgs = torch.cat((imgs, score_paths.unsqueeze(1)), 1)
+
     # NxKxCxHxW
     if path_correct is None:
         path_correct = torch.zeros(imgs.size(0)) <= 1