Update
[beaver.git] / mygpt.py
index 5ea4668..bd79676 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -14,19 +14,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -57,6 +44,19 @@ class BracketedSequence:
 ######################################################################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+######################################################################
+
+
 class CacheWrapper(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -197,7 +197,6 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
@@ -247,18 +246,18 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, bs):
+    def forward(self, bs, with_readout=True):
         bs.x = F.pad(bs.x, (1, -1))
         bs = self.embedding(bs)
         bs = self.trunk(bs)
-        bs = self.readout(bs)
+        if with_readout:
+            bs = self.readout(bs)
         return bs
 
 
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
     vocabulary_size = 10