Update
authorFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 07:19:36 +0000 (08:19 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 07:19:36 +0000 (08:19 +0100)
beaver.py
mygpt.py

index b0e8a78..7adb804 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -26,9 +26,7 @@ else:
 
 ######################################################################
 
-parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
 
 parser.add_argument("--log_filename", type=str, default="train.log")
 
@@ -196,7 +194,6 @@ class TaskMaze(Task):
         )
         mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
         self.train_input = self.map2seq(mazes_train, paths_train)
-        self.nb_codes = self.train_input.max() + 1
 
         mazes_test, paths_test = maze.create_maze_data(
             nb_test_samples,
@@ -208,6 +205,8 @@ class TaskMaze(Task):
         mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
         self.test_input = self.map2seq(mazes_test, paths_test)
 
+        self.nb_codes = self.train_input.max() + 1
+
     def batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
index 5ea4668..df6eab6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -14,19 +14,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -57,6 +44,19 @@ class BracketedSequence:
 ######################################################################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+######################################################################
+
+
 class CacheWrapper(nn.Module):
     def __init__(self, *f):
         super().__init__()