result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
- prediction=True,
+ show_to_be_predicted=True,
)
- def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+ def save_quizzes(
+ self,
+ result_dir,
+ filename_prefix,
+ quizzes,
+ show_to_be_predicted=False,
+ mistakes=None,
+ ):
quizzes = quizzes.clone()
forward = quizzes[quizzes[:, 0] == self.token_forward]
ib = quizzes[:, 0] == self.token_backward
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if prediction:
- predicted_prompts = ib
- predicted_answers = torch.logical_not(ib)
+ if show_to_be_predicted:
+ predicted_prompts = ib.long()
+ predicted_answers = 1 - predicted_prompts
+ if mistakes is not None:
+ # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
+ predicted_prompts *= mistakes
+ predicted_answers *= mistakes
+ else:
+ # 0/2 ~ not-to-predict / to predict
+ predicted_prompts *= 2
+ predicted_answers *= 2
else:
predicted_prompts = None
predicted_answers = None
device=self.device,
)
+ mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1
+
self.save_quizzes(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=result[:72],
- prediction=True,
+ show_to_be_predicted=True,
+ mistakes=mistakes[:72],
)
return main_test_accuracy