Update.
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index f4a4840..e377d2f 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -185,7 +185,7 @@ def create_maze_data(
 ######################################################################
 
 
-def save_image(name, mazes, target_paths, predicted_paths=None):
+def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
     mazes, target_paths = mazes.cpu(), target_paths.cpu()
 
     colors = torch.tensor(
@@ -204,7 +204,7 @@ def save_image(name, mazes, target_paths, predicted_paths=None):
         .reshape(target_paths.size() + (-1,))
         .permute(0, 3, 1, 2)
     )
-    img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
@@ -213,11 +213,29 @@ def save_image(name, mazes, target_paths, predicted_paths=None):
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
-        img = torch.cat((img, predicted_paths.unsqueeze(1)), 1)
-
-    img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
-
-    torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6)
+        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+
+    # NxKxCxHxW
+    if path_correct is None:
+        path_correct = torch.zeros(imgs.size(0)) <= 1
+    path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+    img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+        [255, 0, 0]
+    ).view(1, -1, 1, 1) * (1 - path_correct)
+    img = img.expand(
+        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+    ).clone()
+    for k in range(imgs.size(1)):
+        img[
+            :,
+            :,
+            1 : 1 + imgs.size(3),
+            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+        ] = imgs[:, k]
+
+    img = img.float() / 255.0
+
+    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
 
 
 ######################################################################