- def reverse_time(self, c_quizzes):
- l = (c_quizzes.size(1) - 1) // 2
- direction = c_quizzes[:, 0:1]
- direction = self.token_forward * (
- direction == self.token_backward
- ) + self.token_backward * (direction == self.token_forward)
+ def forward_to_backward(self, c_quizzes):
+ prompts = c_quizzes[:, 1 : 1 + self.prompt_len]
+ answers = c_quizzes[:, 2 + self.prompt_len :]
+ return torch.cat(
+ [
+ c_quizzes.new_full((c_quizzes, 1), self.token_backward),
+ answers,
+ c_quizzes.new_full((c_quizzes, 1), self.token_backward),
+ prompts,
+ ],
+ dim=1,
+ )