Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index d09e860..81afcd9 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -98,10 +98,10 @@ def compute_policy(walls, goal_i, goal_j):
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]
-    value[1, :, :-1] = distance[:, 1:]
-    value[2, 1:, :] = distance[:-1, :]
-    value[3, :-1, :] = distance[1:, :]
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
 
     proba = (value.min(dim=0)[0][None] == value).float()
     proba = proba / proba.sum(dim=0)[None]
@@ -111,18 +111,19 @@ def compute_policy(walls, goal_i, goal_j):
 
 
 def stationary_densities(mazes, policies):
+    policies = policies * (mazes != v_goal)[:, None]
     start = (mazes == v_start).nonzero(as_tuple=True)
-    probas = mazes.new_zeros(mazes.size())
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
     pred_probas = probas.clone()
     probas[start] = 1.0
 
     while not pred_probas.equal(probas):
         pred_probas.copy_(probas)
         probas.zero_()
-        probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
-        probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
-        probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
-        probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
         probas[start] = 1.0
 
     return probas
@@ -211,6 +212,7 @@ def save_image(
     target_paths=None,
     predicted_paths=None,
     score_paths=None,
+    score_truth=None,
     path_correct=None,
 ):
     colors = torch.tensor(
@@ -229,6 +231,17 @@ def save_image(
         colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
     )
 
+    if score_truth is not None:
+        score_truth = score_truth.cpu()
+        c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_truth = (
+            c_score_truth * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+            mazes.unsqueeze(1) == v_empty
+        ) * c_score_truth
+
     imgs = c_mazes.unsqueeze(1)
 
     if target_paths is not None: