Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 13:21:13 +0000 (15:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 13:21:13 +0000 (15:21 +0200)
grids.py
quiz_machine.py

index eaba99a..99a9240 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -136,19 +136,6 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
-        assert check_structure(quizzes, struct)
-
-        ar_mask = quizzes.new_zeros(quizzes.size())
-
-        a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
-        a[:, 0, :] = mask[0]
-        a[:, 1, :] = mask[1]
-        a[:, 2, :] = mask[2]
-        a[:, 3, :] = mask[3]
-
-        return ar_mask
-
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
 
@@ -166,6 +153,29 @@ class Grids(problem.Problem):
 
         return result
 
+    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+        assert check_structure(quizzes, struct)
+
+        ar_mask = quizzes.new_zeros(quizzes.size())
+
+        a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
+        a[:, 0, :] = mask[0]
+        a[:, 1, :] = mask[1]
+        a[:, 2, :] = mask[2]
+        a[:, 3, :] = mask[3]
+
+        return ar_mask
+
+    def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
+        q = quizzes.reshape(-1, 4, S + 1)
+        return (
+            (q[:, 0, 0] == self.l2tok[struct[0]])
+            & (q[:, 1, 0] == self.l2tok[struct[1]])
+            & (q[:, 2, 0] == self.l2tok[struct[2]])
+            & (q[:, 3, 0] == self.l2tok[struct[3]])
+        )
+
     def __init__(
         self,
         max_nb_cached_chunks=None,
@@ -1368,12 +1378,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def trivial_prompts_and_answers(self, prompts, answers):
-        S = self.height * self.width
-        Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
-        f_Bs = answers[:, 1:]
-        return (Bs == f_Bs).long().min(dim=-1).values > 0
-
     def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
             tasks = self.all_tasks
@@ -1425,8 +1429,21 @@ if __name__ == "__main__":
     nb = 5
     quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
     print(grids.get_structure(quizzes))
-    blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
-    print(grids.get_structure(blah))
+    quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+    print(grids.get_structure(quizzes))
+
+    i = torch.rand(quizzes.size(0)) < 0.5
+
+    quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
+
+    j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
+
+    print(
+        i.equal(j),
+        grids.get_structure(quizzes[j]),
+        grids.get_structure(quizzes[j == False]),
+    )
+
     exit(0)
 
     # nb = 1000
index d62ba3b..bc2a358 100755 (executable)
@@ -17,36 +17,6 @@ from mygpt import BracketedSequence
 
 import threading
 
-######################################################################
-# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
-# | X != Y)
-
-
-# output is NxCxT and target is NxT
-def confusion(output, target, reduction="mean"):
-    N, C, T = output.shape
-    output = output.permute(0, 2, 1).reshape(-1, C)
-    target = target.flatten()
-    all_t = torch.arange(N * T, device=output.device)
-    output = output.log_softmax(dim=-1)
-    result = -output[all_t, target]
-
-    output[all_t, target] = float("-inf")
-    output = output.log_softmax(dim=-1)
-    e = output.exp()
-    output[all_t, target] = 0
-    result = result - (output * e).sum(-1)
-
-    if reduction == "none":
-        return result.reshape(N, T)
-    elif reduction == "mean":
-        return result.reshape(N, T).mean()
-    elif reduction == "sum":
-        return result.reshape(N, T).sum()
-    else:
-        raise ValueError(f"unknown reduction '{reduction}'.")
-
-
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -139,47 +109,6 @@ def masked_inplace_autoregression(
 
 
 class QuizMachine:
-    def indices_p2a_and_a2p(self, quizzes):
-        i_p2a = quizzes[:, 0] == self.problem.token_forward
-        j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
-        i_a2p = quizzes[:, 0] == self.problem.token_backward
-        j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
-        assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all()
-        return i_p2a, i_a2p
-
-    def non_trivial(self, quizzes):
-        quizzes = quizzes.clone()
-        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-        quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])  # a_fa_b_fb
-        return torch.logical_not(
-            self.problem.trivial_prompts_and_answers(
-                quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
-            )
-        )
-
-    def p_a_flip_half_in_place(self, quizzes):
-        i = torch.rand(quizzes.size(0)) < 0.5
-        if i.any():
-            quizzes[i] = self.problem.p_a_flip(quizzes[i])
-
-    def generate_token_sequences(self, nb):
-        prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
-        if self.prompt_len is None:
-            self.prompt_len = prompts.size(1)
-
-        if self.answer_len is None:
-            self.answer_len = answers.size(1)
-
-        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
-        result = []
-
-        for prompt, answer in zip(prompts, answers):
-            result.append(torch.cat([prompt, answer], dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
     def __init__(
         self,
         problem,
@@ -191,8 +120,6 @@ class QuizMachine:
     ):
         super().__init__()
 
-        self.nb_token_values = problem.nb_token_values()
-
         self.problem = problem
         self.back_accuracy = back_accuracy
         self.batch_size = batch_size
@@ -205,47 +132,8 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        quizzes,
-        mistakes=None,
-        show_part_to_predict=True,
-    ):
-        quizzes = quizzes.clone().to("cpu")
-        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-        p2a = quizzes[i_p2a]
-        a2p = quizzes[i_a2p]
-        assert p2a.size(0) + a2p.size(0) == quizzes.size(0)
-        quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])
-
-        if show_part_to_predict:
-            predicted_prompts = i_a2p.long()
-            predicted_answers = 1 - predicted_prompts
-            if mistakes is not None:
-                # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-                predicted_prompts *= mistakes.to("cpu")
-                predicted_answers *= mistakes.to("cpu")
-            else:
-                # 0/2 ~ not-to-predict / to predict
-                predicted_prompts *= 2
-                predicted_answers *= 2
-        else:
-            predicted_prompts = None
-            predicted_answers = None
-
-        self.problem.save_quiz_illustrations(
-            result_dir,
-            filename_prefix,
-            quizzes[:, : self.prompt_len],
-            quizzes[:, self.prompt_len :],
-            predicted_prompts,
-            predicted_answers,
-        )
-
     def vocabulary_size(self):
-        return self.nb_token_values
+        return self.problem.nb_token_values
 
     ######################################################################
 
@@ -289,11 +177,12 @@ class QuizMachine:
     def produce_results(
         self, n_epoch, model, input, result_dir, deterministic_synthesis
     ):
-        def compute_accuracy(input, log_prefix=None):
-            input = input.to(self.device)
-            ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
-            result = input.clone() * (1 - ar_mask)
-            seq_logproba = torch.empty(input.size(0), device=self.device)
+        def predict(input, struct, mask):
+            ar_mask = self.problem.make_ar_mask(
+                quizzes=quizzes, struct=struct, mask=mask
+            )
+            result = quizzes * (1 - ar_mask)
+            seq_logproba = torch.empty(fwd_quizzes, device=self.device)
 
             masked_inplace_autoregression(
                 model=model,
@@ -306,37 +195,29 @@ class QuizMachine:
                 device=self.device,
             )
 
-            correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
-
-            i_p2a, i_a2p = self.indices_p2a_and_a2p(input)
-
-            correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values
-
-            if self.back_accuracy and i_a2p.any():
-                # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.problem.p_a_flip(result[i_a2p])
-                back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len]
-                _, correct[i_a2p] = compute_accuracy(back_input)
-
-            if log_prefix is not None:
-                p2a_nb_correct = correct[i_p2a].sum()
-                p2a_nb_total = correct[i_p2a].size(0)
-                a2p_nb_correct = correct[i_a2p].sum()
-                a2p_nb_total = correct[i_a2p].size(0)
-
-                self.logger(
-                    f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"
-                )
+            nb_correct = (result == quizzes).min(dim=1).long()
 
             return result, correct
 
-        test_result, test_correct = compute_accuracy(input, log_prefix="test")
+        input = input.to(self.device)
+        i = self.problem.indices_select(quizzes=input, struct=struct)
 
-        n_test_p2a = input[:, 0] == self.problem.token_forward
+        test_result_fwd, test_correct_fwd = predict(
+            input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        )
 
-        p2a_test_correct = test_correct[n_test_p2a]
+        input_bck = self.problem.reconfigure(
+            predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0],
+            struct=("A", "f_A", "B", "f_B"),
+        )
+
+        l = input_bck.size(1)
+        input_bck[:, 3 * l :] = input[i == False][:, :l]
+        test_result_bck, test_correct_bck = predict(
+            input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        )
 
-        main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0)
+        main_test_accuracy = test_correct.sum() / test_correct.size(0)
 
         ##############################
 
@@ -351,19 +232,27 @@ class QuizMachine:
 
     ######################################################################
 
-    def create_w_quizzes(
-        self, model, nb_train_samples, nb_test_samples, p2a_only=False
-    ):
-        model.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        model.test_w_quizzes = self.generate_token_sequences(nb_test_samples)
+    def flip_half_in_place(self, quizzes):
+        r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5
+        i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B"))
+        quizzes[i & r] = self.problem.reconfigure(
+            quizzes[i & r], struct=("f_B", "f_A", "B", "A")
+        )
+        j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A"))
+        quizzes[j & r] = self.problem.reconfigure(
+            quizzes[j & r], struct=("A", "f_A", "B", "f_B")
+        )
 
-        if not p2a_only:
-            self.p_a_flip_half_in_place(model.train_w_quizzes)
-            self.p_a_flip_half_in_place(model.test_w_quizzes)
+    def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
+        model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
+        model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
+
+        self.flip_half_in_place(model.train_w_quizzes)
+        self.flip_half_in_place(model.test_w_quizzes)
 
     ######################################################################
 
-    def renew_train_w_quizzes(self, model, p2a_only=False):
+    def renew_train_w_quizzes(self, model):
         if hasattr(model, "hard_w_quizzes"):
             self.logger(
                 f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
@@ -379,19 +268,18 @@ class QuizMachine:
                 model.train_w_quizzes[...] = torch.cat(
                     [
                         model.hard_w_quizzes,
-                        self.generate_token_sequences(
+                        self.problem.generate_w_quizzes(
                             model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
                         ),
                     ],
                     dim=0,
                 )
         else:
-            model.train_w_quizzes[...] = self.generate_token_sequences(
+            model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
                 model.train_w_quizzes.size(0)
             )
 
-        if not p2a_only:
-            self.p_a_flip_half_in_place(model.train_w_quizzes)
+        self.flip_half_in_place(model.train_w_quizzes)
 
     ######################################################################
 
@@ -481,9 +369,7 @@ class QuizMachine:
             # -------------------------------
 
             # f(A), A, f(B) | B
-            c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to(
-                self.device
-            )
+            c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device)
             result = c_quizzes.clone()
 
             ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
@@ -512,7 +398,6 @@ class QuizMachine:
         self,
         nb,
         model_for_generation,
-        p2a_only=False,
         temperature_hot=1.0,
         temperature_cold=1.0,
     ):
@@ -533,68 +418,42 @@ class QuizMachine:
         # )
         # lt_clean = None
 
-        if p2a_only:
-            c_quizzes[...] = self.problem.token_forward
+        c_quizzes[...] = self.problem.token_backward
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_noisy,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_clean,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
-
-        else:
-            c_quizzes[...] = self.problem.token_backward
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_noisy,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_clean,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
 
-            c_quizzes = self.problem.p_a_flip(c_quizzes)
+        c_quizzes = self.problem.p_a_flip(c_quizzes)
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_clean,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
 
         return c_quizzes.to("cpu")