X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=maze.py;h=e377d2f0e6b2ab18684c93b6d85f7c6093cc67ea;hb=27bb2d1ab23422f26b05f88b4e0573deeb075cd2;hp=f4a4840c352d26406e9bd80077ef4cf407f9bc95;hpb=d44d0605fed828b8cea08c8e1c5bda7e4528ea97;p=beaver.git diff --git a/maze.py b/maze.py index f4a4840..e377d2f 100755 --- a/maze.py +++ b/maze.py @@ -185,7 +185,7 @@ def create_maze_data( ###################################################################### -def save_image(name, mazes, target_paths, predicted_paths=None): +def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None): mazes, target_paths = mazes.cpu(), target_paths.cpu() colors = torch.tensor( @@ -204,7 +204,7 @@ def save_image(name, mazes, target_paths, predicted_paths=None): .reshape(target_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) - img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1) + imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1) if predicted_paths is not None: predicted_paths = predicted_paths.cpu() @@ -213,11 +213,29 @@ def save_image(name, mazes, target_paths, predicted_paths=None): .reshape(predicted_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) - img = torch.cat((img, predicted_paths.unsqueeze(1)), 1) - - img = img.reshape((-1,) + img.size()[2:]).float() / 255.0 - - torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6) + imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1) + + # NxKxCxHxW + if path_correct is None: + path_correct = torch.zeros(imgs.size(0)) <= 1 + path_correct = path_correct.cpu().long().view(-1, 1, 1, 1) + img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor( + [255, 0, 0] + ).view(1, -1, 1, 1) * (1 - path_correct) + img = img.expand( + -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) + ).clone() + for k in range(imgs.size(1)): + img[ + :, + :, + 1 : 1 + imgs.size(3), + 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4), + ] = imgs[:, k] + + img = img.float() / 255.0 + + torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256) ######################################################################