X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=7adb804cf690dd4caf5a080d4e6c331c8285ab30;hb=c4eb660976808b873f32fe873819c4988aaf2ea5;hp=b0e8a78beed5666177307470bf7af031f2c5d55f;hpb=f2e47caba9966d03bff15d3058fa208a0778b160;p=beaver.git diff --git a/beaver.py b/beaver.py index b0e8a78..7adb804 100755 --- a/beaver.py +++ b/beaver.py @@ -26,9 +26,7 @@ else: ###################################################################### -parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasoning task." -) +parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.") parser.add_argument("--log_filename", type=str, default="train.log") @@ -196,7 +194,6 @@ class TaskMaze(Task): ) mazes_train, paths_train = mazes_train.to(device), paths_train.to(device) self.train_input = self.map2seq(mazes_train, paths_train) - self.nb_codes = self.train_input.max() + 1 mazes_test, paths_test = maze.create_maze_data( nb_test_samples, @@ -208,6 +205,8 @@ class TaskMaze(Task): mazes_test, paths_test = mazes_test.to(device), paths_test.to(device) self.test_input = self.map2seq(mazes_test, paths_test) + self.nb_codes = self.train_input.max() + 1 + def batches(self, split="train", nb_to_use=-1): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input