Update
authorFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 08:11:38 +0000 (09:11 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 20 Mar 2023 08:11:38 +0000 (09:11 +0100)
maze.py

diff --git a/maze.py b/maze.py
index 44bef7c..36eef25 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -61,11 +61,11 @@ def create_maze(h=11, w=17, nb_walls=8):
 ######################################################################
 
 
-def compute_distance(walls, i, j):
+def compute_distance(walls, goal_i, goal_j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
-    dist[i, j] = 0
+    dist[goal_i, goal_j] = 0
     pred_dist = torch.empty_like(dist)
 
     while True:
@@ -93,8 +93,8 @@ def compute_distance(walls, i, j):
 ######################################################################
 
 
-def compute_policy(walls, i, j):
-    distance = compute_distance(walls, i, j)
+def compute_policy(walls, goal_i, goal_j):
+    distance = compute_distance(walls, goal_i, goal_j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
@@ -110,6 +110,21 @@ def compute_policy(walls, i, j):
     return proba
 
 
+def stationary_density(policy, start_i, start_j):
+    probas = policy.new_zeros(policy.size()[:-1])
+    pred_probas = probas.clone()
+    probas[start_i, start_j] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :]
+        probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :]
+        probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1]
+        probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:]
+        probas[start_i, start_j] = 1.0
+
+
 ######################################################################