Update.
author François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 13:41:13 +0000 (15:41 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 13:41:13 +0000 (15:41 +0200)
grids.py
main.py
quiz_machine.py

index 5778f85..406c0b7 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -1413,6 +1413,7 @@ class Grids(problem.Problem):
                 m = (d < self.height * self.width).long()
                 X[i, j] = c[-1]
                 f_X[...] = m * c[-1] + (1 - m) * f_X
+                f_X[i, j] = 0
 
                 if accept_full or (d * (X == 0)).max() == self.height * self.width:
                     break
diff --git a/main.py b/main.py
index 61820dd..7feb3b9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -441,7 +441,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
     nb_to_validate = nb_for_train + nb_for_test
-    nb_to_generate_per_iteration = nb_to_validate
+    nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
     nb_validated = 0
 
     recorded_validated = []
@@ -485,6 +485,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # This is nb_quizzes x nb_models
         number_correct_responses = 0
 
+        remains = [c_quizzes.size(0)]
+
         for r in range(args.nb_rounds):
             number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
 
@@ -500,7 +502,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             c_quizzes = c_quizzes[to_keep]
             number_correct_responses = number_correct_responses[to_keep]
 
-            log_string(f"round {r} remains {c_quizzes.size(0)}")
+            remains.append(c_quizzes.size(0))
 
             if c_quizzes.size(0) == 0:
                 break
@@ -528,6 +530,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         else:
             e = "???"
 
+        v = " ".join([str(x) for x in remains])
+        log_string(f"filter c_quizzes {v}")
+
         log_string(
             f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
@@ -552,6 +557,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
+
+        number_correct_responses = 0
+        for r in range(args.nb_rounds):
+            number_correct_responses += quiz_machine.models_successes(models, vq)
+
+        with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f:
+            for n, r in enumerate(number_correct_responses):
+                v = " ".join([str(c.item()) for c in r])
+                f.write(f"{n}: {v}\n")
+
         quiz_machine.save_quiz_illustrations(
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
index d6c686e..4b07de3 100755 (executable)
@@ -605,3 +605,136 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
+
+    def generate_c_quizzes_mixing(
+        self,
+        nb,
+        model_for_generation,
+        p2a_only=False,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
+    ):
+        c_quizzes = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        c_quizzes_1 = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        c_quizzes_2 = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        ######################################################################
+
+        c_quizzes_1[...] = self.problem.token_backward
+        ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes_1,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
+
+        c_quizzes_2[...] = self.problem.token_backward
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes_2,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
+
+        h = len(model_for_generation.trunk) // 2
+
+        with torch.autograd.no_grad():
+            t = model_for_generation.training
+            model_for_generation.eval()
+
+            bs1 = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(c_quizzes_1), end_layer=h
+            )
+            bs2 = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(c_quizzes_2), end_layer=h
+            )
+
+            alpha = 0.5
+
+            output = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
+                start_layer=h,
+            ).x
+
+            dist = torch.distributions.categorical.Categorical(logits=output)
+            c_quizzes[...] = dist.sample()
+
+            c_quizzes[...] = (
+                ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
+            )
+
+            model_for_generation.train(t)
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
+
+        ######################################################################
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
+
+        c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
+
+        print("DONE")
+        exit(0)
+
+        return c_quizzes.to("cpu")