Update.
author: François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 14:40:22 +0000 (16:40 +0200)
committer: François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 14:40:22 +0000 (16:40 +0200)
quiz_machine.py

index 2d38fab..e70b903 100755 (executable)
@@ -144,18 +144,13 @@ class QuizMachine:
         j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
         i_a2p = quizzes[:, 0] == self.problem.token_backward
         j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
-        assert torch.logical_or(
-            torch.logical_and(i_p2a, j_p2a),
-            torch.logical_and(i_a2p, j_a2p),
-        ).all()
+        assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all()
         return i_p2a, i_a2p
 
     def non_trivial(self, quizzes):
         quizzes = quizzes.clone()
-        n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
-        n_a2p = quizzes[:, 0] == self.problem.token_backward
-        a2p = quizzes[n_a2p]
-        quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
+        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
+        quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])  # a_fa_b_fb
         return torch.logical_not(
             self.problem.trivial_prompts_and_answers(
                 quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
@@ -219,14 +214,14 @@ class QuizMachine:
         show_part_to_predict=True,
     ):
         quizzes = quizzes.clone().to("cpu")
-        n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
-        n_a2p = quizzes[:, 0] == self.problem.token_backward
-        a2p = quizzes[n_a2p]
-        assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
-        quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
+        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
+        p2a = quizzes[i_p2a]
+        a2p = quizzes[i_a2p]
+        assert p2a.size(0) + a2p.size(0) == quizzes.size(0)
+        quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])
 
         if show_part_to_predict:
-            predicted_prompts = n_a2p.long()
+            predicted_prompts = i_a2p.long()
             predicted_answers = 1 - predicted_prompts
             if mistakes is not None:
                 # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@ -313,22 +308,21 @@ class QuizMachine:
 
             correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
 
-            n_p2a = input[:, 0] == self.problem.token_forward
-            n_a2p = input[:, 0] == self.problem.token_backward
+            i_p2a, i_a2p = self.indices_p2a_and_a2p(input)
 
-            correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values
+            correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values
 
-            if self.back_accuracy and n_a2p.any():
+            if self.back_accuracy and i_a2p.any():
                 # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.problem.p_a_flip(result[n_a2p])
-                back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
-                _, correct[n_a2p] = compute_accuracy(back_input)
+                back_input = self.problem.p_a_flip(result[i_a2p])
+                back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len]
+                _, correct[i_a2p] = compute_accuracy(back_input)
 
             if log_prefix is not None:
-                p2a_nb_correct = correct[n_p2a].sum()
-                p2a_nb_total = correct[n_p2a].size(0)
-                a2p_nb_correct = correct[n_a2p].sum()
-                a2p_nb_total = correct[n_a2p].size(0)
+                p2a_nb_correct = correct[i_p2a].sum()
+                p2a_nb_total = correct[i_p2a].size(0)
+                a2p_nb_correct = correct[i_a2p].sum()
+                a2p_nb_total = correct[i_a2p].size(0)
 
                 self.logger(
                     f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"