+ def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+ print(f"DEBUG {quizzes.size()=}")
+ l = (quizzes.size(1) - 1) // 2
+ forward = (quizzes[:, 0] == self.token_forward).long()
+ backward = (quizzes[:, 0] == self.token_backward).long()
+ assert forward.equal(1 - backward)
+ first = quizzes[:, 1 : 1 + l]
+ second = quizzes[:, 1 + l : 1 + 2 * l]
+ prompts = forward[:, None] * first + backward[:, None] * second
+ answers = forward[:, None] * second + backward[:, None] * first
+
+ if prediction:
+ predicted_prompts = backward
+ predicted_answers = forward
+ else:
+ predicted_prompts = None
+ predicted_answers = None
+
+ self.problem.save_quizzes(
+ result_dir,
+ filename_prefix,
+ prompts,
+ answers,
+ predicted_prompts,
+ predicted_answers,
+ )
+