Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 2c44319..81afcd9 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -61,11 +61,11 @@ def create_maze(h=11, w=17, nb_walls=8):
 ######################################################################
 
 
-def compute_distance(walls, i, j):
+def compute_distance(walls, goal_i, goal_j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
-    dist[i, j] = 0
+    dist[goal_i, goal_j] = 0
     pred_dist = torch.empty_like(dist)
 
     while True:
@@ -93,15 +93,15 @@ def compute_distance(walls, i, j):
 ######################################################################
 
 
-def compute_policy(walls, i, j):
-    distance = compute_distance(walls, i, j)
+def compute_policy(walls, goal_i, goal_j):
+    distance = compute_distance(walls, goal_i, goal_j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]
-    value[1, :, :-1] = distance[:, 1:]
-    value[2, 1:, :] = distance[:-1, :]
-    value[3, :-1, :] = distance[1:, :]
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
 
     proba = (value.min(dim=0)[0][None] == value).float()
     proba = proba / proba.sum(dim=0)[None]
@@ -110,33 +110,48 @@ def compute_policy(walls, i, j):
     return proba
 
 
+def stationary_densities(mazes, policies):
+    policies = policies * (mazes != v_goal)[:, None]
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
+    pred_probas = probas.clone()
+    probas[start] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+        probas[start] = 1.0
+
+    return probas
+
+
 ######################################################################
 
 
-def mark_path(walls, i, j, goal_i, goal_j):
-    policy = compute_policy(walls, goal_i, goal_j)
+def mark_path(walls, i, j, goal_i, goal_j, policy):
     action = torch.distributions.categorical.Categorical(
         policy.permute(1, 2, 0)
     ).sample()
-    walls[i, j] = 4
     n, nmax = 0, walls.numel()
     while i != goal_i or j != goal_j:
         di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
         i, j = i + di, j + dj
         assert walls[i, j] == 0
-        walls[i, j] = 4
+        walls[i, j] = v_path
         n += 1
         assert n < nmax
 
 
-def valid_paths(mazes, paths):
+def path_correctness(mazes, paths):
     still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0
     reached = still_ok.new_zeros(still_ok.size())
     current, pred_current = paths.clone(), paths.new_zeros(paths.size())
     goal = (mazes == v_goal).long()
     while not pred_current.equal(current):
-        # print(current)
-        # print(f'{still_ok=} {reached=}')
         pred_current.copy_(current)
         u = (current == v_start).long()
         possible_next = (
@@ -157,62 +172,140 @@ def valid_paths(mazes, paths):
 ######################################################################
 
 
-def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
-    mazes = torch.empty(nb, h, w, dtype=torch.int64)
-    paths = torch.empty(nb, h, w, dtype=torch.int64)
+def create_maze_data(
+    nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+    mazes = torch.empty(nb, height, width, dtype=torch.int64)
+    paths = torch.empty(nb, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width)
 
-    for n in range(nb):
-        maze = create_maze(h, w, nb_walls)
-        i = (1 - maze).nonzero()
+    for n in progress_bar(range(nb)):
+        maze = create_maze(height, width, nb_walls)
+        i = (maze == v_empty).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
             if (start - goal).abs().sum() >= dist_min:
                 break
+        start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
 
+        policy = compute_policy(maze, goal_i, goal_j)
         path = maze.clone()
-        mark_path(path, start[0], start[1], goal[0], goal[1])
-        maze[start[0], start[1]] = v_start
-        maze[goal[0], goal[1]] = v_goal
-        path[start[0], start[1]] = v_start
-        path[goal[0], goal[1]] = v_goal
+        mark_path(path, start_i, start_j, goal_i, goal_j, policy)
+        maze[start_i, start_j] = v_start
+        maze[goal_i, goal_j] = v_goal
+        path[start_i, start_j] = v_start
+        path[goal_i, goal_j] = v_goal
 
         mazes[n] = maze
         paths[n] = path
+        policies[n] = policy
 
-    return mazes, paths
+    return mazes, paths, policies
 
 
 ######################################################################
 
 
-def save_image(name, mazes, paths):
-    mazes, paths = mazes.cpu(), paths.cpu()
-
+def save_image(
+    name,
+    mazes,
+    target_paths=None,
+    predicted_paths=None,
+    score_paths=None,
+    score_truth=None,
+    path_correct=None,
+):
     colors = torch.tensor(
         [
             [255, 255, 255],  # empty
             [0, 0, 0],  # wall
             [0, 255, 0],  # start
-            [0, 0, 255],  # goal
+            [127, 127, 255],  # goal
             [255, 0, 0],  # path
         ]
     )
 
-    mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2)
+    mazes = mazes.cpu()
 
-    img = torch.cat((mazes.unsqueeze(1), paths.unsqueeze(1)), 1)
-    img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
+    c_mazes = (
+        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    )
+
+    if score_truth is not None:
+        score_truth = score_truth.cpu()
+        c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_truth = (
+            c_score_truth * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+            mazes.unsqueeze(1) == v_empty
+        ) * c_score_truth
 
-    torchvision.utils.save_image(img, name, padding=1, pad_value=0.5, nrow=8)
+    imgs = c_mazes.unsqueeze(1)
+
+    if target_paths is not None:
+        target_paths = target_paths.cpu()
+
+        c_target_paths = (
+            colors[target_paths.reshape(-1)]
+            .reshape(target_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+
+        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
+
+    if predicted_paths is not None:
+        predicted_paths = predicted_paths.cpu()
+        c_predicted_paths = (
+            colors[predicted_paths.reshape(-1)]
+            .reshape(predicted_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+    if score_paths is not None:
+        score_paths = score_paths.cpu()
+        c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_paths = (
+            c_score_paths * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+            mazes.unsqueeze(1) != v_empty
+        )
+        imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
+
+    # NxKxCxHxW
+    if path_correct is None:
+        path_correct = torch.zeros(imgs.size(0)) <= 1
+    path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+    img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+        [255, 0, 0]
+    ).view(1, -1, 1, 1) * (1 - path_correct)
+    img = img.expand(
+        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+    ).clone()
+    for k in range(imgs.size(1)):
+        img[
+            :,
+            :,
+            1 : 1 + imgs.size(3),
+            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+        ] = imgs[:, k]
+
+    img = img.float() / 255.0
+
+    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
 
 
 ######################################################################
 
 if __name__ == "__main__":
-
-    mazes, paths = create_maze_data(32, dist_min=10)
-    save_image("test.png", mazes, paths)
-    print(valid_paths(mazes, paths))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mazes, paths = create_maze_data(8)
+    mazes, paths = mazes.to(device), paths.to(device)
+    save_image("test.png", mazes, paths, paths)
+    print(path_correctness(mazes, paths))
 
 ######################################################################