X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=7f0c9e6b7bdd89a77de30ad5cfa9f47d1c8b6257;hb=fc570d4ccd5d5dee36271d34ff5c672a50a82101;hp=37fe6aff89b6b2b07acb1d1d8e5be632a025cc01;hpb=0dbca4cef7405fb92689e5d2542f1d4761d658a3;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 37fe6af..7f0c9e6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -127,10 +127,11 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): + x = torch.cat((x.new_zeros(x.size(0), 1), x), 1) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - return x + return x[:, :-1] ######################################################################