projects
/
mygpt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygpt.git]
/
main.py
diff --git
a/main.py
b/main.py
index
e973291
..
a83107b
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-204,8
+204,9
@@
class TaskPicoCLVR(Task):
t_generated = [ ]
for j in range(nb_tokens):
t_generated = [ ]
for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ]
+ [ 0 ]
]
+ t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
input = torch.tensor(t, device = self.device)
input = torch.tensor(t, device = self.device)
+ input = F.pad(input, (0, 1)) # Add the next token, the one to predict
output = model(input)
logits = output[0, -1]
if args.synthesis_sampling:
output = model(input)
logits = output[0, -1]
if args.synthesis_sampling:
@@
-333,6
+334,7
@@
class TaskWiki103(Task):
for j in range(nb_tokens):
input = self.tensorize([ t_primer + t_generated ]).to(self.device)
for j in range(nb_tokens):
input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+ input = F.pad(input, (0, 1)) # Add the next token, the one to predict
output = model(input)
logits = output[0, -1]
if args.synthesis_sampling:
output = model(input)
logits = output[0, -1]
if args.synthesis_sampling: