+ imgs = c_mazes.unsqueeze(1)
+
+ if target_paths is not None:
+ target_paths = target_paths.cpu()
+
+ c_target_paths = (
+ colors[target_paths.reshape(-1)]
+ .reshape(target_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+
+ imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
+
+ if predicted_paths is not None:
+ predicted_paths = predicted_paths.cpu()
+ c_predicted_paths = (
+ colors[predicted_paths.reshape(-1)]
+ .reshape(predicted_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+ imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+ if score_paths is not None:
+ score_paths = score_paths.cpu()
+ c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+ c_score_paths = (
+ c_score_paths * colors[4].reshape(1, 3, 1, 1)
+ + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1)
+ ).long()
+ c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+ mazes.unsqueeze(1) != v_empty
+ )
+ imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
+
+ # NxKxCxHxW
+ if path_correct is None:
+ path_correct = torch.zeros(imgs.size(0)) <= 1
+ path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+ img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+ [255, 0, 0]
+ ).view(1, -1, 1, 1) * (1 - path_correct)
+ img = img.expand(
+ -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+ ).clone()
+ for k in range(imgs.size(1)):
+ img[
+ :,
+ :,
+ 1 : 1 + imgs.size(3),
+ 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+ ] = imgs[:, k]
+
+ img = img.float() / 255.0
+
+ torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)