result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
- show_to_be_predicted=True,
+ n_backward=self.train_w_quizzes[:72, 0] == self.token_backward,
)
def save_quizzes(
result_dir,
filename_prefix,
quizzes,
- show_to_be_predicted=False,
+ n_backward=None,
mistakes=None,
):
quizzes = quizzes.clone()
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if show_to_be_predicted:
- predicted_prompts = ib.long()
+ if n_backward is None:
+ predicted_prompts = None
+ predicted_answers = None
+ else:
+ predicted_prompts = n_backward.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- else:
- predicted_prompts = None
- predicted_answers = None
self.problem.save_quizzes(
result_dir,
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
- show_to_be_predicted=True,
+ n_backward=self.test_w_quizzes[:72, 0] == self.token_backward,
mistakes=test_correct[:72] * 2 - 1,
)