input,
ar_mask,
seq_logproba,
- temperature=1.0,
- deterministic_synthesis=False,
+ temperature,
+ deterministic_synthesis,
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
- return not self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
+ return torch.logical_not(
+ self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
)
def reverse_time(self, quizzes):