+ def indices_forward_and_backward(self, quizzes):
+ i_forward = quizzes[:, 0] == self.token_forward
+ j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
+ i_backward = quizzes[:, 0] == self.token_backward
+ j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
+ assert torch.logical_or(
+ torch.logical_and(i_forward, j_forward),
+ torch.logical_and(i_backward, j_backward),
+ ).all()
+ return i_forward, i_backward
+
+ def reverse_time(self, quizzes):
+ i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+
+ forward_to_backward = torch.cat(
+ [
+ quizzes[:, 0:1],
+ quizzes[:, 2 + self.prompt_len :],
+ quizzes[:, 1 + self.prompt_len : 2 + self.prompt_len],
+ quizzes[:, 1 : 1 + self.prompt_len],
+ ],
+ dim=1,
+ )
+ forward_to_backward[:, 0] = self.token_backward
+ forward_to_backward[:, 1 + self.answer_len] = self.token_backward
+
+ backward_to_forward = torch.cat(
+ [
+ quizzes[:, 0:1],
+ quizzes[:, 2 + self.answer_len :],
+ quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
+ quizzes[:, 1 : 1 + self.answer_len],
+ ],
+ dim=1,
+ )
+
+ backward_to_forward[:, 0] = self.token_forward
+ backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
+
+ m = i_forward.long()[:, None]
+
+ return m * forward_to_backward + (1 - m) * backward_to_forward
+
+ def make_ar_mask(self, quizzes, first=False):
+ i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+
+ t = torch.arange(quizzes.size(1), device=quizzes.device)
+
+ if first:
+ m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
+ m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
+ else:
+ m_forward = (t >= 2 + self.prompt_len).long()
+ m_backward = (t >= 2 + self.answer_len).long()
+
+ m = i_forward.long()[:, None]
+
+ return m * m_forward + (1 - m) * m_backward