Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 36eef25..d09e860 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -110,19 +110,22 @@ def compute_policy(walls, goal_i, goal_j):
     return proba
 
 
-def stationary_density(policy, start_i, start_j):
-    probas = policy.new_zeros(policy.size()[:-1])
+def stationary_densities(mazes, policies):
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size())
     pred_probas = probas.clone()
-    probas[start_i, start_j] = 1.0
+    probas[start] = 1.0
 
     while not pred_probas.equal(probas):
         pred_probas.copy_(probas)
         probas.zero_()
-        probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :]
-        probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :]
-        probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1]
-        probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:]
-        probas[start_i, start_j] = 1.0
+        probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
+        probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
+        probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
+        probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+        probas[start] = 1.0
+
+    return probas
 
 
 ######################################################################