+class QuizzMachine:
+ def indices_forward_and_backward(self, quizzes):
+ i_forward = quizzes[:, 0] == self.token_forward
+ j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
+ i_backward = quizzes[:, 0] == self.token_backward
+ j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
+ assert torch.logical_or(
+ torch.logical_and(i_forward, j_forward),
+ torch.logical_and(i_backward, j_backward),
+ ).all()
+ return i_forward, i_backward
+
+ def reverse_time(self, quizzes):
+ i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+
+ forward_to_backward = torch.cat(
+ [
+ quizzes[:, 0:1],
+ quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
+ quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
+ quizzes[:, 1 : 1 + self.prompt_len],
+ ],
+ dim=1,
+ )