+ def non_trivial(self, quizzes):
+ quizzes = quizzes.clone()
+ n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+ n_backward = quizzes[:, 0] == self.token_backward
+ backward = quizzes[n_backward]
+ quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+ return torch.logical_not(
+ self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
+ )
+