Update.
author François Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 06:33:59 +0000 (08:33 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 06:33:59 +0000 (08:33 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 848ac9c..be5e1bd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -526,6 +526,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
+    to_recycle = None
+
     while nb_validated_per_model.sum() < nb_to_validate:
         # We use the model that has generated the fewest quizzes to
         # balance the number of quizzes per model overall
@@ -542,6 +544,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             model_for_generation=model_for_generation,
             temperature_hot=args.temperature_hot,
             temperature_cold=args.temperature_cold,
+            to_recycle=to_recycle,
         )
 
         # We discard the trivial ones, according to a criterion
@@ -561,6 +564,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         number_correct_responses = 0
         nb_remaining = [c_quizzes.size(0)]
+        rejected = []
 
         for r in range(args.nb_rounds):
             if c_quizzes.size(0) == 0:
@@ -577,11 +581,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
                 & (nb_sure_fail <= args.max_fail_to_validate)
             )
 
+            if not to_keep.all():
+                rejected.append(c_quizzes[to_keep == False])
+
             c_quizzes = c_quizzes[to_keep]
             number_correct_responses = number_correct_responses[to_keep]
 
             nb_remaining.append(c_quizzes.size(0))
 
+        to_recycle = torch.cat(rejected, dim=0) if len(rejected) > 0 else None
+
         if c_quizzes.size(0) > 0:
             nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
             recorded_validated.append(c_quizzes)
@@ -606,10 +615,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         v = " ".join([str(n) for n in nb_remaining])
-        log_string(f"filter c_quizzes {v}")
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h) filtering {v}"
         )
 
     validated_quizzes = torch.cat(recorded_validated, dim=0)
@@ -630,7 +638,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     if vq.size(0) > 0:
         number_correct_responses = 0
-        for r in range(10):
+
+        for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"):
             number_correct_responses += quiz_machine.models_successes(models, vq)
 
         comments = []
index 7516aed..13c157e 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -62,52 +62,6 @@ def one_batch_masked_inplace_autoregression(
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
 
-def masked_inplace_autoregression(
-    model,
-    batch_size,
-    input,
-    ar_mask,
-    seq_logproba,
-    logit_transformer=None,
-    deterministic_synthesis=False,
-    forbidden_tokens=None,
-    logit_biases=None,
-    progress_bar_desc=None,
-    device=torch.device("cpu"),
-):
-    assert input.size() == ar_mask.size()
-
-    batches = zip(
-        input.split(batch_size),
-        ar_mask.split(batch_size),
-        seq_logproba.split(batch_size),
-    )
-
-    if progress_bar_desc is not None:
-        batches = tqdm.tqdm(
-            batches,
-            dynamic_ncols=True,
-            desc=progress_bar_desc,
-            total=(input.size(0) + batch_size - 1) // batch_size,
-        )
-
-    with torch.autograd.no_grad():
-        t = model.training
-        model.eval()
-
-        for input, ar_mask, seq_logproba in batches:
-            one_batch_masked_inplace_autoregression(
-                model=model,
-                input=input,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                logit_transformer=logit_transformer,
-                deterministic_synthesis=deterministic_synthesis,
-            )
-
-        model.train(t)
-
-
 ######################################################################
 
 
@@ -147,6 +101,53 @@ class QuizMachine:
 
     ######################################################################
 
+    def autoregression(
+        self,
+        model,
+        input,
+        ar_mask,
+        seq_logproba=None,
+        logit_transformer=None,
+        deterministic_synthesis=False,
+        progress_bar_desc=None,
+    ):
+        assert input.size() == ar_mask.size()
+
+        if seq_logproba is None:
+            seq_logproba = torch.empty(input.size(0), device=self.device)
+
+        batches = zip(
+            input.split(self.batch_size),
+            ar_mask.split(self.batch_size),
+            seq_logproba.split(self.batch_size),
+        )
+
+        if progress_bar_desc is not None:
+            batches = tqdm.tqdm(
+                batches,
+                dynamic_ncols=True,
+                desc=progress_bar_desc,
+                total=(input.size(0) + self.batch_size - 1) // self.batch_size,
+            )
+
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            for input, ar_mask, seq_logproba in batches:
+                one_batch_masked_inplace_autoregression(
+                    model=model,
+                    input=input,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba,
+                    logit_transformer=logit_transformer,
+                    deterministic_synthesis=deterministic_synthesis,
+                )
+
+            model.train(t)
+
+    ######################################################################
+
     def data_input(self, model, split="train"):
         assert split in {"train", "test"}
 
@@ -194,16 +193,14 @@ class QuizMachine:
         ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
         seq_logproba = torch.empty(quizzes.size(0), device=self.device)
 
-        masked_inplace_autoregression(
+        self.autoregression(
             model=model,
-            batch_size=self.batch_size,
             input=result,
             ar_mask=ar_mask,
             seq_logproba=seq_logproba,
             progress_bar_desc="accuracy",
-            device=self.device,
         )
 
         correct = (result == quizzes).min(dim=1).values.long()
@@ -400,13 +395,11 @@ class QuizMachine:
                 result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
             )
 
-            masked_inplace_autoregression(
+            self.autoregression(
                 model=model,
-                batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                device=self.device,
             )
 
             correct = (c_quizzes == result).long().min(dim=-1).values
@@ -420,13 +413,11 @@ class QuizMachine:
                 result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
             )
 
-            masked_inplace_autoregression(
+            self.autoregression(
                 model=model,
-                batch_size=self.batch_size,
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                device=self.device,
             )
 
             correct *= (reversed_c_quizzes == result).long().min(dim=-1).values
@@ -445,6 +436,7 @@ class QuizMachine:
         model_for_generation,
         temperature_hot=1.0,
         temperature_cold=1.0,
+        to_recycle=None,
     ):
         c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B"))
         c_quizzes = c_quizzes.to(self.device)
@@ -454,42 +446,41 @@ class QuizMachine:
         lt_noisy = lambda s, logits: logits / temperature_hot
         lt_clean = lambda s, logits: logits / temperature_cold
 
-        masked_inplace_autoregression(
+        self.autoregression(
             model=model_for_generation,
-            batch_size=self.batch_size,
             input=c_quizzes,
             ar_mask=self.make_ar_mask(
                 c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_noisy,
-            device=self.device,
         )
 
-        masked_inplace_autoregression(
+        if to_recycle is not None:
+            l = c_quizzes.size(1) // 4
+            self.logger(f"recycling {to_recycle.size(0)} rejected quizzes")
+            c_quizzes[: to_recycle.size(0), :l] = to_recycle[:, 3 * l :]
+
+        self.autoregression(
             model=model_for_generation,
-            batch_size=self.batch_size,
             input=c_quizzes,
             ar_mask=self.make_ar_mask(
                 c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
-            device=self.device,
         )
 
         c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
 
-        masked_inplace_autoregression(
+        self.autoregression(
             model=model_for_generation,
-            batch_size=self.batch_size,
             input=c_quizzes,
             ar_mask=self.make_ar_mask(
                 c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_clean,
-            device=self.device,
         )
 
         return c_quizzes.to("cpu")
@@ -514,16 +505,75 @@ class QuizMachine:
 
         lt_noisy = lambda s, logits: logits / temperature_hot
 
-        masked_inplace_autoregression(
+        self.autoregression(
             model=model_for_generation,
-            batch_size=self.batch_size,
             input=c_quizzes,
             ar_mask=self.make_ar_mask(
                 c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1)
             ),
             seq_logproba=seq_logproba,
             logit_transformer=lt_noisy,
-            device=self.device,
+        )
+
+        return c_quizzes.to("cpu")
+
+    ######################################################################
+
+    def generate_c_quizzes_2(
+        self,
+        nb,
+        model_for_generation,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
+    ):
+        warnings.warn(
+            "**************************** simple quiz generation", RuntimeWarning
+        )
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
+        c_quizzes = c_quizzes.to(self.device)
+
+        self.autoregression(
+            model=model_for_generation,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(
+                c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 0, 0)
+            ),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+        )
+
+        c_quizzes2 = self.problem.create_empty_quizzes(nb, ("B", "f_B", "A", "f_A"))
+        c_quizzes2 = c_quizzes2.to(self.device)
+
+        self.autoregression(
+            model=model_for_generation,
+            input=c_quizzes2,
+            ar_mask=self.make_ar_mask(
+                c_quizzes2,
+                ("B", "f_B", "A", "f_A"),
+                (1, 0, 0, 0),
+            ),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+        )
+
+        l = c_quizzes.size(1) // 4
+        c_quizzes[:, 2 * l : 3 * l] = c_quizzes2[:, :l]
+
+        self.autoregression(
+            model=model_for_generation,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(
+                c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            ),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
         )
 
         return c_quizzes.to("cpu")