def stationary_densities(mazes, policies):
    """Propagate visit densities from the start cell(s) of each maze.

    Every cell equal to the module-level constant ``v_start`` is clamped to
    a density of 1.0.  At each sweep, the density of a cell is pushed to its
    four grid neighbours, weighted by the policy, and the contributions
    arriving in a cell from different neighbours are summed.

    Args:
        mazes: tensor indexed as ``[n, i, j]`` holding cell codes (it is
            compared against ``v_start``; presumably an integer-coded grid
            -- confirm against the caller).
        policies: tensor indexed as ``[n, d, i, j]`` where ``d`` selects the
            move taken from cell ``(i, j)``:
            0: to ``(i + 1, j)``, 1: to ``(i - 1, j)``,
            2: to ``(i, j + 1)``, 3: to ``(i, j - 1)``.

    Returns:
        A tensor with the same shape and device as ``mazes`` and the dtype
        of ``policies``, holding the density of every cell.

    NOTE(review): the loop stops only when a sweep leaves the densities
    exactly unchanged.  That is reached in finitely many sweeps when the
    policy moves mass along paths without cycles; confirm that the policies
    passed in guarantee this (e.g. no outgoing move from the goal cell).
    """
    start = (mazes == v_start).nonzero(as_tuple=True)

    # Use the dtype of the policies: with the dtype of `mazes`, an integer
    # grid would silently truncate every fractional probability.
    probas = mazes.new_zeros(mazes.size(), dtype=policies.dtype)
    pred_probas = probas.clone()
    probas[start] = 1.0

    while not pred_probas.equal(probas):
        pred_probas.copy_(probas)
        probas.zero_()
        # Accumulate (+=) rather than assign: the four target slices
        # overlap, so plain assignment would discard the mass coming from
        # the directions handled earlier.
        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
        # The start cells act as a constant source.
        probas[start] = 1.0

    return probas