if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- mazes, paths = create_maze_data(8)
+ mazes, paths, policies = create_maze_data(8)
mazes, paths = mazes.to(device), paths.to(device)
- save_image("test.png", mazes, paths, paths)
+ save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
print(path_correctness(mazes, paths))
######################################################################