X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=maze.py;h=6c3fe94fe189159a4a450fb90c33f02d73cb6e92;hb=7143cc544d2b0af03150d9ee05f3cf21319c693b;hp=d11ab6ef177fbef75dcd354e38da45e4df4f717f;hpb=39e24a2f9076db2d512791e723e7f2dc0275d99c;p=beaver.git diff --git a/maze.py b/maze.py index d11ab6e..6c3fe94 100755 --- a/maze.py +++ b/maze.py @@ -158,11 +158,11 @@ def create_maze_data( ): mazes = torch.empty(nb, height, width, dtype=torch.int64) paths = torch.empty(nb, height, width, dtype=torch.int64) - policies = torch.empty(nb, 4, height, width, dtype=torch.int64) + policies = torch.empty(nb, 4, height, width) for n in progress_bar(range(nb)): maze = create_maze(height, width, nb_walls) - i = (1 - maze).nonzero() + i = (maze == v_empty).nonzero() while True: start, goal = i[torch.randperm(i.size(0))[:2]] if (start - goal).abs().sum() >= dist_min: