X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=3bce361424ec4a8b37150c300dc0df05b01de7ef;hb=c3621f9a75cd4d79410d90a29dc9fdec401eaa2d;hp=7ff10358e77cce589ca9d1d53a5a5682ebb2e451;hpb=b2e05688f21ae9f49298c8e291940211b0e3007e;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 7ff1035..3bce361 100755 --- a/mygpt.py +++ b/mygpt.py @@ -97,6 +97,10 @@ class MyGPT(nn.Module): AddPositionalEncoding(len_max), ) + # Small embedding initialization + with torch.no_grad(): + self.embedding[0].weight.normal_(0, 2e-2) + trunk_blocks = [ ] for _ in range(nb_blocks):