X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=main.py;h=a83107b4bda4b2b7f5ba57d2a1db1dcbdd5ff734;hb=a1a7cb9e680378db521f2a1e2139db0e2db903de;hp=e973291a15517bb3c33a96f6f3f1ab7db4545b86;hpb=38c9162209ddc1894da6805a3c7459d8c2b3a13d;p=mygpt.git diff --git a/main.py b/main.py index e973291..a83107b 100755 --- a/main.py +++ b/main.py @@ -204,8 +204,9 @@ class TaskPicoCLVR(Task): t_generated = [ ] for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ] + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] input = torch.tensor(t, device = self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: @@ -333,6 +334,7 @@ class TaskWiki103(Task): for j in range(nb_tokens): input = self.tensorize([ t_primer + t_generated ]).to(self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: