Update.
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 2c44319..f4a4840 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -129,14 +129,12 @@ def mark_path(walls, i, j, goal_i, goal_j):
         assert n < nmax
 
 
-def valid_paths(mazes, paths):
+def path_correctness(mazes, paths):
     still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0
     reached = still_ok.new_zeros(still_ok.size())
     current, pred_current = paths.clone(), paths.new_zeros(paths.size())
     goal = (mazes == v_goal).long()
     while not pred_current.equal(current):
-        # print(current)
-        # print(f'{still_ok=} {reached=}')
         pred_current.copy_(current)
         u = (current == v_start).long()
         possible_next = (
@@ -157,12 +155,14 @@ def valid_paths(mazes, paths):
 ######################################################################
 
 
-def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
-    mazes = torch.empty(nb, h, w, dtype=torch.int64)
-    paths = torch.empty(nb, h, w, dtype=torch.int64)
+def create_maze_data(
+    nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+    mazes = torch.empty(nb, height, width, dtype=torch.int64)
+    paths = torch.empty(nb, height, width, dtype=torch.int64)
 
-    for n in range(nb):
-        maze = create_maze(h, w, nb_walls)
+    for n in progress_bar(range(nb)):
+        maze = create_maze(height, width, nb_walls)
         i = (1 - maze).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
@@ -185,8 +185,8 @@ def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
 ######################################################################
 
 
-def save_image(name, mazes, paths):
-    mazes, paths = mazes.cpu(), paths.cpu()
+def save_image(name, mazes, target_paths, predicted_paths=None):
+    mazes, target_paths = mazes.cpu(), target_paths.cpu()
 
     colors = torch.tensor(
         [
@@ -199,20 +199,35 @@ def save_image(name, mazes, paths):
     )
 
     mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2)
+    target_paths = (
+        colors[target_paths.reshape(-1)]
+        .reshape(target_paths.size() + (-1,))
+        .permute(0, 3, 1, 2)
+    )
+    img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+    if predicted_paths is not None:
+        predicted_paths = predicted_paths.cpu()
+        predicted_paths = (
+            colors[predicted_paths.reshape(-1)]
+            .reshape(predicted_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+        img = torch.cat((img, predicted_paths.unsqueeze(1)), 1)
 
-    img = torch.cat((mazes.unsqueeze(1), paths.unsqueeze(1)), 1)
     img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
 
-    torchvision.utils.save_image(img, name, padding=1, pad_value=0.5, nrow=8)
+    torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6)
 
 
 ######################################################################
 
 if __name__ == "__main__":
 
-    mazes, paths = create_maze_data(32, dist_min=10)
-    save_image("test.png", mazes, paths)
-    print(valid_paths(mazes, paths))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mazes, paths = create_maze_data(8)
+    mazes, paths = mazes.to(device), paths.to(device)
+    save_image("test.png", mazes, paths, paths)
+    print(path_correctness(mazes, paths))
 
 ######################################################################