X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=131c822c76620076721cdb3a7722544dd6ea70b2;hb=22415499c0a91922e51f9e2cade009fd404351dc;hp=77c29ce909549fca9487e9e50564ce7e01f67932;hpb=621231cc5bb94f983c556a1b450b66067bec4165;p=picoclvr.git diff --git a/mygpt.py b/mygpt.py index 77c29ce..131c822 100755 --- a/mygpt.py +++ b/mygpt.py @@ -264,6 +264,7 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): + # print(f"GENERATE {bs.first} {bs.first+bs.nb}") bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs)