Update
[beaver.git] / mygpt.py
index df6eab6..bd79676 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -197,7 +197,6 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
@@ -247,18 +246,18 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, bs):
+    def forward(self, bs, with_readout=True):
         bs.x = F.pad(bs.x, (1, -1))
         bs = self.embedding(bs)
         bs = self.trunk(bs)
-        bs = self.readout(bs)
+        if with_readout:
+            bs = self.readout(bs)
         return bs
 
 
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
     vocabulary_size = 10