projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
maze.py
diff --git
a/maze.py
b/maze.py
index
fd0a1d2
..
f6715f0
100755
(executable)
--- a/
maze.py
+++ b/
maze.py
@@
-13,6
+13,8
@@
v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
def create_maze(h=11, w=17, nb_walls=8):
def create_maze(h=11, w=17, nb_walls=8):
+ assert h % 2 == 1 and w % 2 == 1
+
a, k = 0, 0
while k < nb_walls:
a, k = 0, 0
while k < nb_walls:
@@
-285,7
+287,7
@@
def save_image(
)
imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
)
imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
- img = torch.tensor([2
24, 224, 224
]).view(1, -1, 1, 1)
+ img = torch.tensor([2
55, 255, 0
]).view(1, -1, 1, 1)
# NxKxCxHxW
if path_optimal is not None:
# NxKxCxHxW
if path_optimal is not None:
@@
-322,9
+324,9
@@
def save_image(
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- mazes, paths = create_maze_data(8)
+ mazes, paths
, policies
= create_maze_data(8)
mazes, paths = mazes.to(device), paths.to(device)
mazes, paths = mazes.to(device), paths.to(device)
- save_image("test.png", mazes
, paths,
paths)
+ save_image("test.png", mazes
=mazes, target_paths=paths, predicted_paths=
paths)
print(path_correctness(mazes, paths))
######################################################################
print(path_correctness(mazes, paths))
######################################################################