Added a null token, which is the one to predict.
[mygpt.git] / mygpt.py
index 37fe6af..5370ffa 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -111,7 +111,6 @@ class MyGPT(nn.Module):
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),
-                    nn.Linear(in_features = dim_model, out_features = dim_model),
                 ),
                 Residual(
                     nn.LayerNorm(dim_model),
@@ -127,10 +126,11 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
+        x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x
+        return x[:, :-1]
 
 ######################################################################