Update.
[beaver.git] / beaver.py
index 920a446..a289867 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -75,11 +75,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # maze options
 
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
 
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
 
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
 
 ######################################################################
 
@@ -262,7 +262,13 @@ class TaskMaze(Task):
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
-            maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+            maze.save_image(
+                f"result_{n_epoch:04d}.png",
+                mazes,
+                paths,
+                predicted_paths,
+                maze.path_correctness(mazes, predicted_paths),
+            )
 
             model.train(t)
 
@@ -276,9 +282,9 @@ task = TaskMaze(
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.batch_size,
-    height=args.world_height,
-    width=args.world_width,
-    nb_walls=args.world_nb_walls,
+    height=args.maze_height,
+    width=args.maze_width,
+    nb_walls=args.maze_nb_walls,
     device=device,
 )