From 29cd6ffe24dfbc5720efe8b123ec1973d868881a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 20 Mar 2023 09:11:38 +0100 Subject: [PATCH] Update --- maze.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/maze.py b/maze.py index 44bef7c..36eef25 100755 --- a/maze.py +++ b/maze.py @@ -61,11 +61,11 @@ def create_maze(h=11, w=17, nb_walls=8): ###################################################################### -def compute_distance(walls, i, j): +def compute_distance(walls, goal_i, goal_j): max_length = walls.numel() dist = torch.full_like(walls, max_length) - dist[i, j] = 0 + dist[goal_i, goal_j] = 0 pred_dist = torch.empty_like(dist) while True: @@ -93,8 +93,8 @@ def compute_distance(walls, i, j): ###################################################################### -def compute_policy(walls, i, j): - distance = compute_distance(walls, i, j) +def compute_policy(walls, goal_i, goal_j): + distance = compute_distance(walls, goal_i, goal_j) distance = distance + walls.numel() * walls value = distance.new_full((4,) + distance.size(), walls.numel()) @@ -110,6 +110,21 @@ def compute_policy(walls, i, j): return proba +def stationary_density(policy, start_i, start_j): + probas = policy.new_zeros(policy.size()[:-1]) + pred_probas = probas.clone() + probas[start_i, start_j] = 1.0 + + while not pred_probas.equal(probas): + pred_probas.copy_(probas) + probas.zero_() + probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :] + probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :] + probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1] + probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:] + probas[start_i, start_j] = 1.0 + + ###################################################################### -- 2.39.5