j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
i_a2p = quizzes[:, 0] == self.problem.token_backward
j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
- assert torch.logical_or(
- torch.logical_and(i_p2a, j_p2a),
- torch.logical_and(i_a2p, j_a2p),
- ).all()
+ assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all()
return i_p2a, i_a2p
def non_trivial(self, quizzes):
quizzes = quizzes.clone()
- n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
- n_a2p = quizzes[:, 0] == self.problem.token_backward
- a2p = quizzes[n_a2p]
- quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
+ i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
+ quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) # a_fa_b_fb
return torch.logical_not(
self.problem.trivial_prompts_and_answers(
quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
show_part_to_predict=True,
):
quizzes = quizzes.clone().to("cpu")
- n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
- n_a2p = quizzes[:, 0] == self.problem.token_backward
- a2p = quizzes[n_a2p]
- assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
- quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
+ i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
+ p2a = quizzes[i_p2a]
+ a2p = quizzes[i_a2p]
+ assert p2a.size(0) + a2p.size(0) == quizzes.size(0)
+ quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])
if show_part_to_predict:
- predicted_prompts = n_a2p.long()
+ predicted_prompts = i_a2p.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
- n_p2a = input[:, 0] == self.problem.token_forward
- n_a2p = input[:, 0] == self.problem.token_backward
+ i_p2a, i_a2p = self.indices_p2a_and_a2p(input)
- correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values
+ correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values
- if self.back_accuracy and n_a2p.any():
+ if self.back_accuracy and i_a2p.any():
# accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.problem.p_a_flip(result[n_a2p])
- back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
- _, correct[n_a2p] = compute_accuracy(back_input)
+ back_input = self.problem.p_a_flip(result[i_a2p])
+ back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len]
+ _, correct[i_a2p] = compute_accuracy(back_input)
if log_prefix is not None:
- p2a_nb_correct = correct[n_p2a].sum()
- p2a_nb_total = correct[n_p2a].size(0)
- a2p_nb_correct = correct[n_a2p].sum()
- a2p_nb_total = correct[n_a2p].size(0)
+ p2a_nb_correct = correct[i_p2a].sum()
+ p2a_nb_total = correct[i_p2a].size(0)
+ a2p_nb_correct = correct[i_a2p].sum()
+ a2p_nb_total = correct[i_a2p].size(0)
self.logger(
f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"