return proba
-def stationary_density(policy, start_i, start_j):
- probas = policy.new_zeros(policy.size()[:-1])
+def stationary_densities(mazes, policies):
+ start = (mazes == v_start).nonzero(as_tuple=True)
+ probas = mazes.new_zeros(mazes.size())
pred_probas = probas.clone()
- probas[start_i, start_j] = 1.0
+ probas[start] = 1.0
while not pred_probas.equal(probas):
pred_probas.copy_(probas)
probas.zero_()
- probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :]
- probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :]
- probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1]
- probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:]
- probas[start_i, start_j] = 1.0
+ probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
+ probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
+ probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
+ probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+ probas[start] = 1.0
+
+ return probas
######################################################################