Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 44bef7c..d09e860 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -61,11 +61,11 @@ def create_maze(h=11, w=17, nb_walls=8):
 ######################################################################
 
 
 ######################################################################
 
 
-def compute_distance(walls, i, j):
+def compute_distance(walls, goal_i, goal_j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
-    dist[i, j] = 0
+    dist[goal_i, goal_j] = 0
     pred_dist = torch.empty_like(dist)
 
     while True:
     pred_dist = torch.empty_like(dist)
 
     while True:
@@ -93,8 +93,8 @@ def compute_distance(walls, i, j):
 ######################################################################
 
 
 ######################################################################
 
 
-def compute_policy(walls, i, j):
-    distance = compute_distance(walls, i, j)
+def compute_policy(walls, goal_i, goal_j):
+    distance = compute_distance(walls, goal_i, goal_j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
@@ -110,6 +110,24 @@ def compute_policy(walls, i, j):
     return proba
 
 
     return proba
 
 
+def stationary_densities(mazes, policies):
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size())
+    pred_probas = probas.clone()
+    probas[start] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
+        probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
+        probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
+        probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+        probas[start] = 1.0
+
+    return probas
+
+
 ######################################################################
 
 
 ######################################################################