Update
[beaver.git] / beaver.py
index 920a446..7adb804 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -26,9 +26,7 @@ else:
 
 ######################################################################
 
-parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
 
 parser.add_argument("--log_filename", type=str, default="train.log")
 
@@ -75,11 +73,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # maze options
 
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
 
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
 
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
 
 ######################################################################
 
@@ -196,7 +194,6 @@ class TaskMaze(Task):
         )
         mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
         self.train_input = self.map2seq(mazes_train, paths_train)
-        self.nb_codes = self.train_input.max() + 1
 
         mazes_test, paths_test = maze.create_maze_data(
             nb_test_samples,
@@ -208,6 +205,8 @@ class TaskMaze(Task):
         mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
         self.test_input = self.map2seq(mazes_test, paths_test)
 
+        self.nb_codes = self.train_input.max() + 1
+
     def batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -227,6 +226,7 @@ class TaskMaze(Task):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
@@ -256,13 +256,19 @@ class TaskMaze(Task):
             input = self.test_input[:32]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
-
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
-            maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+            maze.save_image(
+                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+                mazes,
+                paths,
+                predicted_paths,
+                maze.path_correctness(mazes, predicted_paths),
+            )
 
             model.train(t)
 
@@ -276,9 +282,9 @@ task = TaskMaze(
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.batch_size,
-    height=args.world_height,
-    width=args.world_width,
-    nb_walls=args.world_nb_walls,
+    height=args.maze_height,
+    width=args.maze_width,
+    nb_walls=args.maze_nb_walls,
     device=device,
 )