Update: add an optional "order" permutation to the positional encoding
author     François Fleuret <francois@fleuret.org>
           Tue, 21 Mar 2023 22:37:47 +0000 (23:37 +0100)
committer  François Fleuret <francois@fleuret.org>
           Tue, 21 Mar 2023 22:37:47 +0000 (23:37 +0100)
README.txt
mygpt.py

diff --git a/README.txt b/README.txt
index dc13a4f..8265b48 100644
--- a/README.txt
+++ b/README.txt
@@ -2,3 +2,7 @@ To train the shortest-path solving GPT, and train the one-shot MLP
 read-out:
 
   ./beaver.py --oneshot
+
+Same, but with lighter settings (~95% test success instead of ~99%):
+
+  ./beaver.py --nb_train_samples=25000 --nb_test_samples=10000 --nb_epochs=10 --oneshot
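
For context, a minimal sketch of how beaver.py presumably declares the flags used above with argparse; the flag names come from the commands, but the defaults are placeholders, not the real values:

  # Hypothetical sketch: argparse declarations matching the flags above.
  # Defaults are illustrative assumptions, not taken from beaver.py.
  import argparse

  parser = argparse.ArgumentParser(description="shortest-path solving GPT")
  parser.add_argument("--nb_train_samples", type=int, default=250000)
  parser.add_argument("--nb_test_samples", type=int, default=10000)
  parser.add_argument("--nb_epochs", type=int, default=25)
  parser.add_argument("--oneshot", action="store_true")
  args = parser.parse_args()
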
diff --git a/mygpt.py b/mygpt.py
index 0b63ac8..232b604 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -85,7 +85,7 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs):
+    def forward(self, bs, order=None):
         if bs.first == 0:
             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                 :, None
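
Note on the formula in the comment above: the implementation in the next hunk uses a single sin call for both cases, because the phase offset pi/2*k turns sin into cos on the odd dimensions. A standalone sketch, with T, D, and len_max as assumed example values:

  # Standalone sketch of the encoding: one sin covers both sin and cos,
  # since sin(x + pi/2) == cos(x).
  import math
  import torch

  T, D, len_max = 6, 8, 1024
  t = torch.arange(T, dtype=torch.float)[:, None]  # positions, (T, 1)
  j = torch.arange(D, dtype=torch.float)[None, :]  # dimensions, (1, D)
  k = j % 2                                        # 0 on even dims, 1 on odd
  pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)  # (T, D)

  # Odd dimension 1 is cos of the same argument as even dimension 0.
  assert torch.allclose(pe[:, 1], torch.cos(t[:, 0]))
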
@@ -97,6 +97,10 @@ class AddPositionalEncoding(nn.Module):
             self.pe = torch.sin(
                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
             )
+
+            if order is not None:
+                self.pe = self.pe.gather(1, order.unsqueeze(-1).expand_as(self.pe))
+
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
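
The new order argument bakes a per-sample permutation of positions into the cached encoding. A self-contained sketch of the gather call above, under assumed shapes (pe broadcast to (N, T, D), order of shape (N, T)):

  # Sketch: permute a positional encoding along time, per sample.
  import torch

  N, T, D = 2, 5, 4
  pe = torch.randn(1, T, D)                                   # shared encoding
  order = torch.stack([torch.randperm(T) for _ in range(N)])  # (N, T)

  idx = order.unsqueeze(-1).expand(N, T, D)       # replicate index over D
  pe_perm = pe.expand(N, T, D).gather(1, idx)     # out[n, t] = pe[0, order[n, t]]

  assert torch.equal(pe_perm[1, 3], pe[0, order[1, 3]])
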
@@ -201,10 +205,10 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+        self.embedding = CacheWrapper(
+            nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
         )
+        self.pe = AddPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -246,9 +250,13 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, bs, mode="standard"):
-        bs.x = F.pad(bs.x, (1, -1))
+    def forward(self, bs, mode="standard", order=None):
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        if order is not None:
+            order = F.pad(order + 1, (1, -1))
         bs = self.embedding(bs)
+        bs = self.pe(bs, order)
+
         if mode == "standard":
             bs = self.trunk(bs)
             bs = self.readout(bs)
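
On the padding above: F.pad(x, (1, -1)) prepends one element along the last dimension and drops the final one, i.e. a right shift so that position t only sees the tokens before t. The permutation is shifted the same way; the +1 is, presumably, to keep the indices consistent with the newly inserted slot. A runnable sketch of both calls:

  # Runnable sketch of the two F.pad calls in forward().
  import torch
  import torch.nn.functional as F

  x = torch.tensor([[11, 12, 13, 14]])
  print(F.pad(x, (1, -1)))           # tensor([[ 0, 11, 12, 13]])

  order = torch.tensor([[2, 0, 1, 3]])
  print(F.pad(order + 1, (1, -1)))   # tensor([[0, 3, 1, 2]])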