X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=d6879dc08a29f05cac1998bc1ab16e46db07821c;hb=e68f19634d3282e39a488d146480b19bb23e8652;hp=212e1a5d7f013b04e40f15e36a9f90b4de63483d;hpb=7b6a1a4f12459fd18a2006fa8f11589f2b2cd87b;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 212e1a5..d6879dc 100755 --- a/mygpt.py +++ b/mygpt.py @@ -104,7 +104,7 @@ class MyGPT(nn.Module): for _ in range(nb_blocks): trunk_blocks += [ Residual( - nn.LayerNorm(dim_model), + nn.LayerNorm((dim_model,)), QKVAttention( dim_in = dim_model, dim_qk = dim_keys, @@ -114,7 +114,7 @@ class MyGPT(nn.Module): ), ), Residual( - nn.LayerNorm(dim_model), + nn.LayerNorm((dim_model,)), nn.Linear(in_features = dim_model, out_features = dim_hidden), nn.ReLU(), nn.Linear(in_features = dim_hidden, out_features = dim_model), @@ -131,7 +131,8 @@ class MyGPT(nn.Module): x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - return x[:, :-1] + x = F.pad(x, (0, 0, 0, -1)) + return x ######################################################################