X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=7ff10358e77cce589ca9d1d53a5a5682ebb2e451;hb=c0019b5af155be6a8af02bf71a62c43af1d7a178;hp=954f4f088c29116052f3258dd13eda2f5a6b3b0c;hpb=f082ce9255f87b6c1bfebfac00f94820a16e04f1;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 954f4f0..7ff1035 100755 --- a/mygpt.py +++ b/mygpt.py @@ -125,11 +125,10 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): - x = F.pad(x, (1, 0)) + x = F.pad(x, (1, -1)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - x = F.pad(x, (0, 0, 0, -1)) return x ######################################################################