######################################################################
-parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
parser.add_argument("--log_filename", type=str, default="train.log")
)
mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
self.train_input = self.map2seq(mazes_train, paths_train)
- self.nb_codes = self.train_input.max() + 1
mazes_test, paths_test = maze.create_maze_data(
nb_test_samples,
mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
self.test_input = self.map2seq(mazes_test, paths_test)
+ self.nb_codes = self.train_input.max() + 1
+
def batches(self, split="train", nb_to_use=-1):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input