Update.
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index cfdede3..d11ab6e 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -113,18 +113,16 @@ def compute_policy(walls, i, j):
 ######################################################################
 
 
-def mark_path(walls, i, j, goal_i, goal_j):
-    policy = compute_policy(walls, goal_i, goal_j)
+def mark_path(walls, i, j, goal_i, goal_j, policy):
     action = torch.distributions.categorical.Categorical(
         policy.permute(1, 2, 0)
     ).sample()
-    walls[i, j] = 4
     n, nmax = 0, walls.numel()
     while i != goal_i or j != goal_j:
         di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
         i, j = i + di, j + dj
         assert walls[i, j] == 0
-        walls[i, j] = 4
+        walls[i, j] = v_path
         n += 1
         assert n < nmax
 
@@ -160,6 +158,7 @@ def create_maze_data(
 ):
     mazes = torch.empty(nb, height, width, dtype=torch.int64)
     paths = torch.empty(nb, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width, dtype=torch.int64)
 
     for n in progress_bar(range(nb)):
         maze = create_maze(height, width, nb_walls)
@@ -168,18 +167,21 @@ def create_maze_data(
             start, goal = i[torch.randperm(i.size(0))[:2]]
             if (start - goal).abs().sum() >= dist_min:
                 break
+        start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
 
+        policy = compute_policy(maze, goal_i, goal_j)
         path = maze.clone()
-        mark_path(path, start[0], start[1], goal[0], goal[1])
-        maze[start[0], start[1]] = v_start
-        maze[goal[0], goal[1]] = v_goal
-        path[start[0], start[1]] = v_start
-        path[goal[0], goal[1]] = v_goal
+        mark_path(path, start_i, start_j, goal_i, goal_j, policy)
+        maze[start_i, start_j] = v_start
+        maze[goal_i, goal_j] = v_goal
+        path[start_i, start_j] = v_start
+        path[goal_i, goal_j] = v_goal
 
         mazes[n] = maze
         paths[n] = path
+        policies[n] = policy
 
-    return mazes, paths
+    return mazes, paths, policies
 
 
 ######################################################################