) # Needed to initialize the model's cache
for s in range(to_generate.min(), to_generate.max() + 1):
output = self(BracketedSequence(input, s, 1)).x
- logits = output[:, s] / temperature
+ logits = output[:, s]
+
+ logits = logits.log_softmax(dim=-1) / temperature
+
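Applying the temperature to the log-softmax rather than to the raw logits gives the same categorical distribution, because the log-partition term subtracted by log_softmax is constant per row and cancels in the softmax that follows; what changes is the value held by the logits themselves (normalized log-probabilities at temperature 1), which matters only if they are accumulated or reported downstream. A standalone check of this equivalence (not part of the patch):

    import torch

    x = torch.randn(4, 10)   # raw logits, (batch, vocab)
    T = 0.7

    p_raw = (x / T).softmax(dim=-1)                      # temperature on raw logits
    p_log = (x.log_softmax(dim=-1) / T).softmax(dim=-1)  # temperature on log-probs

    print(torch.allclose(p_raw, p_log, atol=1e-6))       # True: the constant cancels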
if forbidden_tokens is not None:
logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+
if forced_biases is not None:
logits = logits + forced_biases[None, :]
+
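A hedged, standalone illustration of the two adjustments above (the shapes are assumptions, they are not shown in the patch): forbidden_tokens acts as a boolean mask over the vocabulary whose entries get probability zero, and forced_biases is a per-token additive bias broadcast over the batch.

    import torch

    # Assumed shapes: logits (batch, vocab), forbidden_tokens (vocab,) bool,
    # forced_biases (vocab,) float
    logits = torch.randn(4, 16)
    forbidden_tokens = torch.zeros(16, dtype=torch.bool)
    forbidden_tokens[0] = True                   # e.g. never emit token 0
    forced_biases = torch.zeros(16)
    forced_biases[3] = 2.0                       # e.g. favor token 3

    logits = logits.masked_fill(forbidden_tokens, float("-inf"))
    logits = logits + forced_biases[None, :]
    probs = logits.softmax(dim=-1)               # forbidden tokens get probability 0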
if deterministic_synthesis:
t_next = logits.argmax(1)
else:
device=self.device,
)
+ # This should not be necessary, though, since the autoregression is
+ # done in eval mode
+ sum_logits = sum_logits.detach()
+
average_logits = sum_logits / quizzes.numel()
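A minimal sketch of this reporting step, under the assumption (the surrounding function is not shown) that sum_logits accumulates the per-token log-probabilities over a batch of quizzes and is then averaged per token; the detach is defensive, since under eval-time autoregression inside torch.no_grad() the sum already carries no autograd history.

    import torch

    with torch.no_grad():
        quizzes = torch.randint(0, 100, (8, 32))                 # hypothetical token ids
        log_probs = torch.randn(8, 32, 100).log_softmax(dim=-1)  # per-position log-probs
        sum_logits = log_probs.gather(2, quizzes[:, :, None]).sum()
        sum_logits = sum_logits.detach()                         # defensive, as in the patch
        average_logits = sum_logits / quizzes.numel()            # mean log-prob per token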
# It's a bit brutal to do it twice, we should probably have a