X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=maze.py;h=e377d2f0e6b2ab18684c93b6d85f7c6093cc67ea;hb=27bb2d1ab23422f26b05f88b4e0573deeb075cd2;hp=2c4431916d1609ba6611996c257165b4626420a8;hpb=3602bfe2c4e1cd513759bf45cb83f8c2d914674b;p=beaver.git

diff --git a/maze.py b/maze.py
index 2c44319..e377d2f 100755
--- a/maze.py
+++ b/maze.py
@@ -129,14 +129,12 @@ def mark_path(walls, i, j, goal_i, goal_j):
         assert n < nmax
 
 
-def valid_paths(mazes, paths):
+def path_correctness(mazes, paths):
     still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0
     reached = still_ok.new_zeros(still_ok.size())
     current, pred_current = paths.clone(), paths.new_zeros(paths.size())
     goal = (mazes == v_goal).long()
     while not pred_current.equal(current):
-        # print(current)
-        # print(f'{still_ok=} {reached=}')
         pred_current.copy_(current)
         u = (current == v_start).long()
         possible_next = (
@@ -157,12 +155,14 @@
 ######################################################################
 
 
-def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
-    mazes = torch.empty(nb, h, w, dtype=torch.int64)
-    paths = torch.empty(nb, h, w, dtype=torch.int64)
+def create_maze_data(
+    nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+    mazes = torch.empty(nb, height, width, dtype=torch.int64)
+    paths = torch.empty(nb, height, width, dtype=torch.int64)
 
-    for n in range(nb):
-        maze = create_maze(h, w, nb_walls)
+    for n in progress_bar(range(nb)):
+        maze = create_maze(height, width, nb_walls)
         i = (1 - maze).nonzero()
         while True:
             start, goal = i[torch.randperm(i.size(0))[:2]]
@@ -185,8 +185,8 @@ def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1):
 ######################################################################
 
 
-def save_image(name, mazes, paths):
-    mazes, paths = mazes.cpu(), paths.cpu()
+def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
+    mazes, target_paths = mazes.cpu(), target_paths.cpu()
 
     colors = torch.tensor(
         [
@@ -199,20 +199,53 @@
     )
 
     mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2)
-
-    img = torch.cat((mazes.unsqueeze(1), paths.unsqueeze(1)), 1)
-    img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
-
-    torchvision.utils.save_image(img, name, padding=1, pad_value=0.5, nrow=8)
+    target_paths = (
+        colors[target_paths.reshape(-1)]
+        .reshape(target_paths.size() + (-1,))
+        .permute(0, 3, 1, 2)
+    )
+    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+    if predicted_paths is not None:
+        predicted_paths = predicted_paths.cpu()
+        predicted_paths = (
+            colors[predicted_paths.reshape(-1)]
+            .reshape(predicted_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+
+    # NxKxCxHxW
+    if path_correct is None:
+        path_correct = torch.zeros(imgs.size(0)) <= 1
+    path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+    img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+        [255, 0, 0]
+    ).view(1, -1, 1, 1) * (1 - path_correct)
+    img = img.expand(
+        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+    ).clone()
+    for k in range(imgs.size(1)):
+        img[
+            :,
+            :,
+            1 : 1 + imgs.size(3),
+            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+        ] = imgs[:, k]
+
+    img = img.float() / 255.0
+
+    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
 
 
 ######################################################################
 
 if __name__ == "__main__":
-    mazes, paths = create_maze_data(32, dist_min=10)
-    save_image("test.png", mazes, paths)
-    print(valid_paths(mazes, paths))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mazes, paths = create_maze_data(8)
+    mazes, paths = mazes.to(device), paths.to(device)
+    save_image("test.png", mazes, paths, paths)
+    print(path_correctness(mazes, paths))
 
 
 ######################################################################
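For reference, a minimal usage sketch of the functions touched by this diff: path_correctness, create_maze_data with its new progress_bar argument, and the extended save_image. It assumes the file is importable as a module named maze and that tqdm is installed; neither assumption comes from the diff itself, and any callable that wraps an iterable (or the default identity lambda) works for progress_bar.

# Sketch only: module name `maze` and the use of tqdm are assumptions, not part of the diff.
import torch
import tqdm

import maze

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# progress_bar can be any callable wrapping an iterable, e.g. tqdm.tqdm.
mazes, paths = maze.create_maze_data(64, progress_bar=tqdm.tqdm)
mazes, paths = mazes.to(device), paths.to(device)

# path_correctness returns one flag per maze; here the "predictions" are just
# the target paths themselves, so every flag should be true. With a model one
# would pass its decoded output instead.
correct = maze.path_correctness(mazes, paths)
print(f"{correct.float().mean().item()=}")

# The per-sample flags drive the frame colour around each row of panels:
# grey ([224, 224, 224]) for correct paths, red ([255, 0, 0]) for incorrect ones.
maze.save_image("mazes.png", mazes, paths, predicted_paths=paths, path_correct=correct)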