Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index d11ab6e..6c3fe94 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -158,11 +158,11 @@ def create_maze_data(
 ):
     mazes = torch.empty(nb, height, width, dtype=torch.int64)
     paths = torch.empty(nb, height, width, dtype=torch.int64)
-    policies = torch.empty(nb, 4, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width)
 
     for n in progress_bar(range(nb)):
         maze = create_maze(height, width, nb_walls)
-        i = (1 - maze).nonzero()
+        i = (maze == v_empty).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
             if (start - goal).abs().sum() >= dist_min: