+def stationary_densities(mazes, policies):
+ start = (mazes == v_start).nonzero(as_tuple=True)
+ probas = mazes.new_zeros(mazes.size())
+ pred_probas = probas.clone()
+ probas[start] = 1.0
+
+ while not pred_probas.equal(probas):
+ pred_probas.copy_(probas)
+ probas.zero_()
+ probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
+ probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
+ probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
+ probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+ probas[start] = 1.0
+
+ return probas
+
+