Update.
author: François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 04:04:03 +0000 (06:04 +0200)
committer: François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 04:04:03 +0000 (06:04 +0200)
grids.py
quiz_machine.py

index b531eb9..ba0131d 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -177,28 +177,17 @@ class Grids(problem.Problem):
             )
         else:
             flipped_from_forward = torch.cat(
-                [
-                    quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
-                    quizzes[:, 0 * (S + 1) : 2 * (S + 1) + S + 1],
-                    quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
-                    quizzes[:, 2 * (S + 1) : 0 * (S + 1) + S + 1],
-                ],
+                [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]],
                 dim=1,
             )
             flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
 
             flipped_from_backward = torch.cat(
-                [
-                    quizzes[:, 1 * (S + 1) : 3 * (S + 1) + S + 1],
-                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
-                    quizzes[:, 3 * (S + 1) : 1 * (S + 1) + S + 1],
-                    quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
-                ],
-                dim=1,
+                [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1
             )
             flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
 
-            m = (flipped[:, 0] == self.token_forward).long()
+            m = (quizzes[:, 0] == self.token_forward).long()[:, None]
 
             flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
 
index 182e9ff..b1f6be1 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -603,215 +603,3 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
-
-    def generate_c_quizzes_fixed_point(
-        self,
-        nb,
-        model_for_generation,
-        p2a_only=False,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        c_quizzes[...] = self.problem.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes)
-
-        c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-        while True:
-            print("ITERATION")
-
-            c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-            pred = c_quizzes.clone()
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_clean,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
-
-            c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-                seq_logproba=seq_logproba,
-                logit_transformer=lt_clean,
-                deterministic_synthesis=False,
-                device=self.device,
-            )
-
-            if pred[202:].equal(c_quizzes[202:]):
-                break
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes)
-
-        exit(0)
-
-        return c_quizzes.to("cpu")
-
-    ######################################################################
-
-    def generate_c_quizzes_mixing(
-        self,
-        nb,
-        model_for_generation,
-        p2a_only=False,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_1 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_2 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        ######################################################################
-
-        c_quizzes_1[...] = self.problem.token_backward
-        ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_1,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
-
-        c_quizzes_2[...] = self.problem.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_2,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
-
-        h = len(model_for_generation.trunk) // 2
-
-        with torch.autograd.no_grad():
-            t = model_for_generation.training
-            model_for_generation.eval()
-
-            bs1 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_1), end_layer=h
-            )
-            bs2 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_2), end_layer=h
-            )
-
-            alpha = 0.1
-
-            output = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
-                start_layer=h,
-            ).x
-
-            dist = torch.distributions.categorical.Categorical(logits=output)
-            c_quizzes[...] = dist.sample()
-
-            c_quizzes[...] = (
-                ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
-            )
-
-            model_for_generation.train(t)
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
-
-        ######################################################################
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
-
-        c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
-
-        print("DONE")
-        exit(0)
-
-        return c_quizzes.to("cpu")