Update.
author: François Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 06:10:58 +0000 (08:10 +0200)
committer: François Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 06:10:58 +0000 (08:10 +0200)
grids.py
main.py
quiz_machine.py

index 4f07d70..c2ff0d1 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -126,6 +126,8 @@ class Grids(problem.Problem):
         tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
+        self.token_forward = len(self.colors)
+        self.token_backward = self.token_forward + 1
         self.height = 10
         self.width = 10
         self.cache_rec_coo = {}
@@ -157,7 +159,7 @@ class Grids(problem.Problem):
 
     def frame2img(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
-        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
+        m = torch.logical_and(x >= 0, x < len(self.colors)).long()
         x = self.colors[x * m].permute(0, 3, 1, 2)
         s = x.shape
         x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
@@ -192,13 +194,19 @@ class Grids(problem.Problem):
         margin=8,
     ):
         S = self.height * self.width
-        As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
-        f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
+        As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
+            -1, self.height, self.width
+        )
+        f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
+            -1, self.height, self.width
+        )
+        Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
             -1, self.height, self.width
         )
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
         prompts = torch.cat([As, f_As, Bs], dim=2)
-        answers = answers.reshape(answers.size(0), self.height, self.width)
+        answers = answers[:, 1 : S + 1].reshape(
+            answers.size(0), self.height, self.width
+        )
 
         if predicted_prompts is None:
             predicted_prompts = 255
@@ -307,7 +315,7 @@ class Grids(problem.Problem):
     ######################################################################
 
     def nb_token_values(self):
-        return len(self.colors)
+        return len(self.colors) + 2
 
     # @torch.compile
     def rec_coo(
@@ -1180,8 +1188,9 @@ class Grids(problem.Problem):
 
     def trivial_prompts_and_answers(self, prompts, answers):
         S = self.height * self.width
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
-        f_Bs = answers
+        Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
+        f_Bs = answers[:, 1:]
+        print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}")
         return (Bs == f_Bs).long().min(dim=-1).values > 0
 
     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
@@ -1189,8 +1198,8 @@ class Grids(problem.Problem):
             tasks = self.all_tasks
 
         S = self.height * self.width
-        prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
-        answers = torch.zeros(nb, S, dtype=torch.int64)
+        prompts = torch.full((nb, 3 * S + 3), self.token_forward)
+        answers = torch.full((nb, S + 1), self.token_forward)
 
         bunch = zip(prompts, answers)
 
@@ -1203,10 +1212,16 @@ class Grids(problem.Problem):
             )
 
         for prompt, answer in bunch:
-            A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
-            f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
-            B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
-            f_B = answer.view(self.height, self.width)
+            A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + 1 + S].view(
+                self.height, self.width
+            )
+            f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + 1 + S].view(
+                self.height, self.width
+            )
+            B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
+                self.height, self.width
+            )
+            f_B = answer[1 : S + 1].view(self.height, self.width)
             task = tasks[torch.randint(len(tasks), (1,)).item()]
             task(A, f_A, B, f_B)
 
diff --git a/main.py b/main.py
index 0d0d373..ab87b56 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -90,11 +90,13 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=1.5)
+parser.add_argument("--temperature_hot", type=float, default=1.5)
+
+parser.add_argument("--temperature_cold", type=float, default=0.75)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
-parser.add_argument("--forward_only", action="store_true", default=False)
+parser.add_argument("--p2a_only", action="store_true", default=False)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -374,8 +376,8 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        n_forward = input[:, 0] == quiz_machine.token_forward
-        to_store = from_w & n_forward.to("cpu")
+        n_p2a = input[:, 0] == quiz_machine.token_p2a
+        to_store = from_w & n_p2a.to("cpu")
         if to_store.any():
             hard_w_quizzes.append(
                 (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))
@@ -454,13 +456,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # We balance the number of quizzes per model
 
         model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
-        print(nb_validated, "using", model_for_generation.id)
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
             model_for_generation=model_for_generation,
-            forward_only=args.forward_only,
-            generation_temperature=args.generation_temperature,
+            p2a_only=args.p2a_only,
+            temperature_hot=args.temperature_hot,
+            temperature_cold=args.temperature_cold,
         )
 
         c_quizzes = keep_good_quizzes(models, c_quizzes)
@@ -536,14 +538,14 @@ for k in range(args.nb_gpts):
         model=model,
         nb=args.nb_train_samples,
         for_train=True,
-        forward_only=args.forward_only,
+        p2a_only=args.p2a_only,
     )
 
     quiz_machine.create_w_quizzes(
         model=model,
         nb=args.nb_test_samples,
         for_train=False,
-        forward_only=args.forward_only,
+        p2a_only=args.p2a_only,
     )
 
     models.append(model)
@@ -673,7 +675,7 @@ for n_epoch in range(args.nb_epochs):
         quiz_machine.renew_w_quizzes(
             model=model,
             for_train=True,
-            forward_only=args.forward_only,
+            p2a_only=args.p2a_only,
         )
 
     if args.log_command is not None:
index 032305a..51c3f08 100755 (executable)
@@ -138,83 +138,72 @@ def masked_inplace_autoregression(
 
 
 class QuizMachine:
-    def indices_forward_and_backward(self, quizzes):
-        i_forward = quizzes[:, 0] == self.token_forward
-        j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
-        i_backward = quizzes[:, 0] == self.token_backward
-        j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
+    def indices_p2a_and_a2p(self, quizzes):
+        i_p2a = quizzes[:, 0] == self.problem.token_forward
+        j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
+        i_a2p = quizzes[:, 0] == self.problem.token_backward
+        j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
         assert torch.logical_or(
-            torch.logical_and(i_forward, j_forward),
-            torch.logical_and(i_backward, j_backward),
+            torch.logical_and(i_p2a, j_p2a),
+            torch.logical_and(i_a2p, j_a2p),
         ).all()
-        return i_forward, i_backward
+        return i_p2a, i_a2p
 
     def non_trivial(self, quizzes):
         quizzes = quizzes.clone()
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
+        n_a2p = quizzes[:, 0] == self.problem.token_backward
+        a2p = quizzes[n_a2p]
+        quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
         return torch.logical_not(
             self.problem.trivial_prompts_and_answers(
-                quizzes[:, 1 : 1 + self.prompt_len],
-                quizzes[:, 2 + self.prompt_len :],
+                quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
             )
         )
 
-    def reverse_time(self, quizzes):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+    def p_a_flip(self, quizzes):
+        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
 
-        forward_to_backward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
-                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
-                quizzes[:, 1 : 1 + self.prompt_len],
-            ],
+        p2a_to_a2p = torch.cat(
+            [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]],
             dim=1,
         )
 
-        forward_to_backward[:, 0] = self.token_backward
-        forward_to_backward[:, 1 + self.answer_len] = self.token_backward
+        p2a_to_a2p[:, 0] = self.problem.token_backward
+        p2a_to_a2p[:, self.answer_len] = self.problem.token_backward
 
-        backward_to_forward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.answer_len :],
-                quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
-                quizzes[:, 1 : 1 + self.answer_len],
-            ],
+        a2p_to_p2a = torch.cat(
+            [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]],
             dim=1,
         )
 
-        backward_to_forward[:, 0] = self.token_forward
-        backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
+        a2p_to_p2a[:, 0] = self.problem.token_forward
+        a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward
 
-        m = i_forward.long()[:, None]
+        m = i_p2a.long()[:, None]
 
-        return m * forward_to_backward + (1 - m) * backward_to_forward
+        return m * p2a_to_a2p + (1 - m) * a2p_to_p2a
 
-    def reverse_random_half_in_place(self, quizzes):
+    def p_a_flip_half_in_place(self, quizzes):
         i = torch.rand(quizzes.size(0)) < 0.5
         if i.any():
-            quizzes[i] = self.reverse_time(quizzes[i])
+            quizzes[i] = self.p_a_flip(quizzes[i])
 
     def make_ar_mask(self, quizzes, first=False):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+        i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
 
         t = torch.arange(quizzes.size(1), device=quizzes.device)
 
         if first:
-            m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
-            m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
+            m_p2a = (t >= 1).long() * (t < self.prompt_len).long()
+            m_a2p = (t >= 1).long() * (t < self.answer_len).long()
         else:
-            m_forward = (t >= 2 + self.prompt_len).long()
-            m_backward = (t >= 2 + self.answer_len).long()
+            m_p2a = (t >= 1 + self.prompt_len).long()
+            m_a2p = (t >= 1 + self.answer_len).long()
 
-        m = i_forward.long()[:, None]
+        m = i_p2a.long()[:, None]
 
-        return m * m_forward + (1 - m) * m_backward
+        return m * m_p2a + (1 - m) * m_a2p
 
     def generate_token_sequences(self, nb):
         prompts, answers = self.problem.generate_prompts_and_answers(nb)
@@ -230,14 +219,7 @@ class QuizMachine:
         result = []
 
         for prompt, answer in zip(prompts, answers):
-            a = [
-                torch.tensor([self.token_forward]),
-                prompt,
-                torch.tensor([self.token_forward]),
-                answer,
-            ]
-
-            result.append(torch.cat(a, dim=0)[None, :])
+            result.append(torch.cat([prompt, answer], dim=0)[None, :])
 
         return torch.cat(result, dim=0)
 
@@ -252,10 +234,7 @@ class QuizMachine:
     ):
         super().__init__()
 
-        v = problem.nb_token_values()
-        self.token_forward = v
-        self.token_backward = v + 1
-        self.nb_token_values = v + 2
+        self.nb_token_values = problem.nb_token_values()
 
         self.problem = problem
         self.back_accuracy = back_accuracy
@@ -278,14 +257,14 @@ class QuizMachine:
         show_part_to_predict=True,
     ):
         quizzes = quizzes.clone().to("cpu")
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
+        n_a2p = quizzes[:, 0] == self.problem.token_backward
+        a2p = quizzes[n_a2p]
+        assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
+        quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
 
         if show_part_to_predict:
-            predicted_prompts = n_backward.long()
+            predicted_prompts = n_a2p.long()
             predicted_answers = 1 - predicted_prompts
             if mistakes is not None:
                 # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@ -371,29 +350,27 @@ class QuizMachine:
 
             correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
 
-            n_forward = input[:, 0] == self.token_forward
-            n_backward = input[:, 0] == self.token_backward
+            n_p2a = input[:, 0] == self.problem.token_forward
+            n_a2p = input[:, 0] == self.problem.token_backward
 
-            correct[n_forward] = (
-                (input[n_forward] == result[n_forward]).long().min(dim=1).values
-            )
+            correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values
 
-            if self.back_accuracy and n_backward.any():
+            if self.back_accuracy and n_a2p.any():
                 # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.reverse_time(result[n_backward])
+                back_input = self.p_a_flip(result[n_a2p])
                 back_input[:, 2 + self.prompt_len :] = input[
-                    n_backward, 1 : 1 + self.answer_len
+                    n_a2p, 1 : 1 + self.answer_len
                 ]
-                _, correct[n_backward] = compute_accuracy(back_input)
+                _, correct[n_a2p] = compute_accuracy(back_input)
 
             if log_prefix is not None:
-                forward_nb_correct = correct[n_forward].sum()
-                forward_nb_total = correct[n_forward].size(0)
-                backward_nb_correct = correct[n_backward].sum()
-                backward_nb_total = correct[n_backward].size(0)
+                p2a_nb_correct = correct[n_p2a].sum()
+                p2a_nb_total = correct[n_p2a].size(0)
+                a2p_nb_correct = correct[n_a2p].sum()
+                a2p_nb_total = correct[n_a2p].size(0)
 
                 self.logger(
-                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"
                 )
 
             return result, correct
@@ -402,11 +379,11 @@ class QuizMachine:
             model.test_w_quizzes[:2000], log_prefix="test"
         )
 
-        n_test_forward = model.test_w_quizzes[:2000, 0] == self.token_forward
+        n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward
 
-        forward_test_correct = test_correct[n_test_forward]
+        p2a_test_correct = test_correct[n_test_p2a]
 
-        main_test_accuracy = forward_test_correct.sum() / forward_test_correct.size(0)
+        main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0)
 
         ##############################
 
@@ -421,11 +398,11 @@ class QuizMachine:
 
     ######################################################################
 
-    def create_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+    def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False):
         input = self.generate_token_sequences(nb)
 
-        if not forward_only:
-            self.reverse_random_half_in_place(input)
+        if not p2a_only:
+            self.p_a_flip_half_in_place(input)
 
         if for_train:
             model.train_w_quizzes = input
@@ -434,7 +411,7 @@ class QuizMachine:
 
     ######################################################################
 
-    def renew_w_quizzes(self, model, for_train=True, forward_only=False):
+    def renew_w_quizzes(self, model, for_train=True, p2a_only=False):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
 
         if for_train and hasattr(model, "hard_w_quizzes"):
@@ -458,8 +435,8 @@ class QuizMachine:
         else:
             input[...] = self.generate_token_sequences(input.size(0))
 
-        if not forward_only:
-            self.reverse_random_half_in_place(input)
+        if not p2a_only:
+            self.p_a_flip_half_in_place(input)
 
     ######################################################################
 
@@ -553,20 +530,25 @@ class QuizMachine:
     ###############################################################
 
     def generate_c_quizzes(
-        self, nb, model_for_generation, forward_only=False, generation_temperature=1.0
+        self,
+        nb,
+        model_for_generation,
+        p2a_only=False,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
     ):
         c_quizzes = torch.empty(
             nb,
-            self.prompt_len + self.answer_len + 2,
+            self.prompt_len + self.answer_len,
             device=self.device,
             dtype=torch.int64,
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
-        if forward_only:
-            c_quizzes[:, 0] = self.token_forward
-            c_quizzes[:, 1 + self.prompt_len] = self.token_forward
+        if p2a_only:
+            c_quizzes[:, 0] = self.problem.token_forward
+            c_quizzes[:, self.prompt_len] = self.problem.token_forward
 
             masked_inplace_autoregression(
                 model=model_for_generation,
@@ -574,7 +556,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
-                temperature=generation_temperature,
+                temperature=temperature_hot,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -585,14 +567,14 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=1.0,
+                temperature=temperature_cold,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
         else:
-            c_quizzes[:, 0] = self.token_backward
-            c_quizzes[:, 1 + self.answer_len] = self.token_backward
+            c_quizzes[:, 0] = self.problem.token_backward
+            c_quizzes[:, self.answer_len] = self.problem.token_backward
 
             masked_inplace_autoregression(
                 model=model_for_generation,
@@ -600,7 +582,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
-                temperature=generation_temperature,
+                temperature=temperature_hot,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -611,12 +593,12 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=0.75,
+                temperature=temperature_cold,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            c_quizzes = self.reverse_time(c_quizzes)
+            c_quizzes = self.p_a_flip(c_quizzes)
 
             masked_inplace_autoregression(
                 model=model_for_generation,
@@ -624,7 +606,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=0.75,
+                temperature=temperature_cold,
                 deterministic_synthesis=False,
                 device=self.device,
             )