Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 11 Mar 2023 21:00:43 +0000 (22:00 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 11 Mar 2023 21:00:43 +0000 (22:00 +0100)
beaver.py
maze.py

index 920a446..a289867 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -75,11 +75,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # maze options
 
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
 
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
 
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
 
 ######################################################################
 
@@ -262,7 +262,13 @@ class TaskMaze(Task):
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
-            maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+            maze.save_image(
+                f"result_{n_epoch:04d}.png",
+                mazes,
+                paths,
+                predicted_paths,
+                maze.path_correctness(mazes, predicted_paths),
+            )
 
             model.train(t)
 
@@ -276,9 +282,9 @@ task = TaskMaze(
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.batch_size,
-    height=args.world_height,
-    width=args.world_width,
-    nb_walls=args.world_nb_walls,
+    height=args.maze_height,
+    width=args.maze_width,
+    nb_walls=args.maze_nb_walls,
     device=device,
 )
 
diff --git a/maze.py b/maze.py
index f4a4840..e377d2f 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -185,7 +185,7 @@ def create_maze_data(
 ######################################################################
 
 
-def save_image(name, mazes, target_paths, predicted_paths=None):
+def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
     mazes, target_paths = mazes.cpu(), target_paths.cpu()
 
     colors = torch.tensor(
@@ -204,7 +204,7 @@ def save_image(name, mazes, target_paths, predicted_paths=None):
         .reshape(target_paths.size() + (-1,))
         .permute(0, 3, 1, 2)
     )
-    img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+    imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
@@ -213,11 +213,29 @@ def save_image(name, mazes, target_paths, predicted_paths=None):
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
-        img = torch.cat((img, predicted_paths.unsqueeze(1)), 1)
-
-    img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
-
-    torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6)
+        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+
+    # NxKxCxHxW
+    if path_correct is None:
+        path_correct = torch.zeros(imgs.size(0)) <= 1
+    path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+    img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+        [255, 0, 0]
+    ).view(1, -1, 1, 1) * (1 - path_correct)
+    img = img.expand(
+        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+    ).clone()
+    for k in range(imgs.size(1)):
+        img[
+            :,
+            :,
+            1 : 1 + imgs.size(3),
+            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+        ] = imgs[:, k]
+
+    img = img.float() / 255.0
+
+    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
 
 
 ######################################################################