S = self.height * self.width
Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
f_Bs = answers
- return (B_s == f_Bs).long().min(dim=-1).values > 0
+ return (Bs == f_Bs).long().min(dim=-1).values > 0
def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
if tasks is None:
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
- return not self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
+ return torch.logical_not(
+ self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
)
def reverse_time(self, quizzes):