input = F.pad(input, (0, 1)) # Add the next token, the one to predict
output = model(input)
logits = output[0, -1]
- if args.synthesis_sampling:
+ if args.deterministic_synthesis:
+ t_next = logits.argmax()
+ else:
dist = torch.distributions.categorical.Categorical(logits = logits)
t_next = dist.sample()
- else:
- t_next = logits.argmax()
t_generated.append(self.vocab.lookup_token(t_next))
if t_generated[-1] == '<nul>': break