From: Francois Fleuret Date: Tue, 26 Jul 2022 11:27:58 +0000 (+0200) Subject: Added a null token, which is the one to predict. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=38c9162209ddc1894da6805a3c7459d8c2b3a13d;p=mygpt.git Added a null token, which is the one to predict. --- diff --git a/main.py b/main.py index 4a332b8..e973291 100755 --- a/main.py +++ b/main.py @@ -204,7 +204,7 @@ class TaskPicoCLVR(Task): t_generated = [ ] for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ] input = torch.tensor(t, device = self.device) output = model(input) logits = output[0, -1]