From eb750bd5d68bb53f3044bd38b65930c9ff53b53d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 16:40:22 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index 2d38fab..e70b903 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -144,18 +144,13 @@ class QuizMachine: j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward i_a2p = quizzes[:, 0] == self.problem.token_backward j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward - assert torch.logical_or( - torch.logical_and(i_p2a, j_p2a), - torch.logical_and(i_a2p, j_a2p), - ).all() + assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all() return i_p2a, i_a2p def non_trivial(self, quizzes): quizzes = quizzes.clone() - n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward] - n_a2p = quizzes[:, 0] == self.problem.token_backward - a2p = quizzes[n_a2p] - quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p]) + i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) + quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) # a_fa_b_fb return torch.logical_not( self.problem.trivial_prompts_and_answers( quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :] @@ -219,14 +214,14 @@ class QuizMachine: show_part_to_predict=True, ): quizzes = quizzes.clone().to("cpu") - n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward] - n_a2p = quizzes[:, 0] == self.problem.token_backward - a2p = quizzes[n_a2p] - assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0) - quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p]) + i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) + p2a = quizzes[i_p2a] + a2p = quizzes[i_a2p] + assert p2a.size(0) + a2p.size(0) == quizzes.size(0) + quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) if show_part_to_predict: - predicted_prompts = n_a2p.long() + predicted_prompts = i_a2p.long() predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct @@ -313,22 +308,21 @@ class QuizMachine: correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) - n_p2a = input[:, 0] == self.problem.token_forward - n_a2p = input[:, 0] == self.problem.token_backward + i_p2a, i_a2p = self.indices_p2a_and_a2p(input) - correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values + correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values - if self.back_accuracy and n_a2p.any(): + if self.back_accuracy and i_a2p.any(): # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.problem.p_a_flip(result[n_a2p]) - back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len] - _, correct[n_a2p] = compute_accuracy(back_input) + back_input = self.problem.p_a_flip(result[i_a2p]) + back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len] + _, correct[i_a2p] = compute_accuracy(back_input) if log_prefix is not None: - p2a_nb_correct = correct[n_p2a].sum() - p2a_nb_total = correct[n_p2a].size(0) - a2p_nb_correct = correct[n_a2p].sum() - a2p_nb_total = correct[n_a2p].size(0) + p2a_nb_correct = correct[i_p2a].sum() + p2a_nb_total = correct[i_p2a].size(0) + a2p_nb_correct = correct[i_a2p].sum() + a2p_nb_total = correct[i_a2p].size(0) self.logger( f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}" -- 2.39.5