Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 36eef25..81afcd9 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -98,10 +98,10 @@ def compute_policy(walls, goal_i, goal_j):
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]
-    value[1, :, :-1] = distance[:, 1:]
-    value[2, 1:, :] = distance[:-1, :]
-    value[3, :-1, :] = distance[1:, :]
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
 
     proba = (value.min(dim=0)[0][None] == value).float()
     proba = proba / proba.sum(dim=0)[None]
@@ -110,19 +110,23 @@ def compute_policy(walls, goal_i, goal_j):
     return proba
 
 
-def stationary_density(policy, start_i, start_j):
-    probas = policy.new_zeros(policy.size()[:-1])
+def stationary_densities(mazes, policies):
+    policies = policies * (mazes != v_goal)[:, None]
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
     pred_probas = probas.clone()
-    probas[start_i, start_j] = 1.0
+    probas[start] = 1.0
 
     while not pred_probas.equal(probas):
         pred_probas.copy_(probas)
         probas.zero_()
-        probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :]
-        probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :]
-        probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1]
-        probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:]
-        probas[start_i, start_j] = 1.0
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+        probas[start] = 1.0
+
+    return probas
 
 
 ######################################################################
@@ -208,6 +212,7 @@ def save_image(
     target_paths=None,
     predicted_paths=None,
     score_paths=None,
+    score_truth=None,
     path_correct=None,
 ):
     colors = torch.tensor(
@@ -226,6 +231,17 @@ def save_image(
         colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
     )
 
+    if score_truth is not None:
+        score_truth = score_truth.cpu()
+        c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_truth = (
+            c_score_truth * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+            mazes.unsqueeze(1) == v_empty
+        ) * c_score_truth
+
     imgs = c_mazes.unsqueeze(1)
 
     if target_paths is not None: