def create_maze(h=11, w=17, nb_walls=8):
+ assert h % 2 == 1 and w % 2 == 1
+
a, k = 0, 0
while k < nb_walls:
)
imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
- img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1)
+ img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1)
# NxKxCxHxW
if path_optimal is not None:
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- mazes, paths = create_maze_data(8)
+ mazes, paths, policies = create_maze_data(8)
mazes, paths = mazes.to(device), paths.to(device)
- save_image("test.png", mazes, paths, paths)
+ save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
print(path_correctness(mazes, paths))
######################################################################