Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 10:37:06 +0000 (13:37 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 10:37:06 +0000 (13:37 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index 9d95034..6137834 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -383,100 +383,76 @@ def run_tests(model, quizz_machine, deterministic_synthesis):
 ######################################################################
 
 
+def valid_c_quizzes(recorded, criteria):
+    result = [q[criteria(c)] for q, c in recorded]
+    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+
+
+######################################################################
+
+
 def create_c_quizzes(
     models,
     quizz_machine,
     nb_for_train=1000,
     nb_for_test=100,
-    min_ave_seq_logproba=None,
 ):
-    # We will store the generated quizzes for each number of
-    # correct prediction
-    recorded = dict([(n, []) for n in range(len(models) + 1)])
+    recorded = []
 
-    model_indexes = []
     sum_logits, sum_nb_c_quizzes = 0, 0
 
-    def nb_generated():
-        return sum([sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()])
-
-    def nb_validated():
-        return sum(
-            [
-                sum([x.size(0) for x in recorded[n]])
-                for n in range(args.min_to_validate, args.max_to_validate + 1)
-            ]
-        )
-
     nb_to_create = nb_for_train + nb_for_test
 
-    warnings.warn(
-        f"{args.nb_gpts=} {args.nb_models_for_generation=} {args.min_to_validate=} {args.max_to_validate=}"
+    # ------------------------------------------------------------
+
+    standard_validity = lambda nb_correct: torch.logical_and(
+        nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
     )
 
-    while nb_validated() < nb_to_create:
-        (
-            new_c_quizzes,
-            nb_correct,
-            ave_seq_logproba,
-        ) = quizz_machine.gang_create_c_quizzes(
-            nb=nb_to_create,
-            nb_models_for_generation=args.nb_models_for_generation,
-            models=models,
-            mode=args.generation_mode,
+    while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create:
+        model_for_generation = models[torch.randint(len(models), (1,))]
+
+        c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes(
+            nb_to_create,
+            model_for_generation=model_for_generation,
             reverse_cleanup=args.reverse_cleanup,
-            min_ave_seq_logproba=min_ave_seq_logproba,
-            n_epoch=n_epoch,
-            result_dir=args.result_dir,
         )
 
-        sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
-        sum_nb_c_quizzes += new_c_quizzes.size(0)
+        sum_logits += c_quizzes.size(0) * ave_seq_logproba
+        sum_nb_c_quizzes += c_quizzes.size(0)
+
+        nb_correct = quizz_machine.comput_correctness(c_quizzes, models)
 
         if args.dirty_debug:
             nb_correct = torch.randint(
-                len(models) + 1, nb_correct.size(), device=new_c_quizzes.device
+                len(models) + 1, nb_correct.size(), device=c_quizzes.device
             )
 
-        for n in range(nb_correct.max() + 1):
-            recorded[n].append(new_c_quizzes[nb_correct == n].clone())
+        recorded.append((c_quizzes, nb_correct))
 
         nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
         nv = " ".join([str(x.item()) for x in nv])
 
-        log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
+        nb_validated = valid_c_quizzes(recorded, standard_validity).size(0)
 
-    # concatenate and shuffle
-    for n in recorded.keys():
-        if len(recorded[n]) > 0:
-            q = torch.cat(recorded[n], dim=0)
-            q = q[torch.randperm(q.size(0), device=q.device)]
-            recorded[n] = q
-        else:
-            del recorded[n]
+        log_string(f"keep c_quizzes kept {nv} total {nb_validated} / {nb_to_create}")
 
-    new_c_quizzes = torch.cat(
-        [recorded[n] for n in range(args.min_to_validate, args.max_to_validate + 1)],
-        dim=0,
-    )
+    # ------------------------------------------------------------
 
-    new_c_quizzes = new_c_quizzes[
-        torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[
-            : nb_for_train + nb_for_test
-        ]
-    ]
+    new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    for n in recorded.keys():
+    for n in range(len(models) + 1):
         s = (
             "_validated"
             if n >= args.min_to_validate and n <= args.max_to_validate
             else ""
         )
+
         quizz_machine.problem.save_quizzes(
-            recorded[n][:72],
+            valid_c_quizzes(recorded, criteria=lambda nb_correct: nb_correct == n)[:72],
             args.result_dir,
             f"culture_c_quiz_{n_epoch:04d}_N{n}{s}",
         )
@@ -511,57 +487,40 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-min_ave_seq_logproba = None
-
 for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
-    a = [(model.id, float(model.main_test_accuracy)) for model in models]
-    a.sort(key=lambda p: p[0])
-    s = " ".join([f"{p[1]*100:.02f}%" for p in a])
-    log_string(f"current accuracies {s}")
-
-    # select the model with lowest accuracy
-    models.sort(key=lambda model: model.main_test_accuracy)
-    model = models[0]
+    weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
 
     log_string(
-        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
     )
 
     # improve it
-    one_epoch(model, quizz_machine)
-
-    quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+    one_epoch(weakest_model, quizz_machine)
 
     log_string(
         f"train_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
     # test it
-    run_tests(model, quizz_machine, deterministic_synthesis=False)
+    run_tests(weakest_model, quizz_machine, deterministic_synthesis=False)
 
     log_string(
         f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
+    # replace a fraction of the w_quizzes with a fresh ones
+    quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+
     if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        ave_seq_logproba = create_c_quizzes(
+        create_c_quizzes(
             models,
             quizz_machine,
             nb_for_train=nb_new_c_quizzes_for_train,
             nb_for_test=nb_new_c_quizzes_for_test,
-            min_ave_seq_logproba=min_ave_seq_logproba,
         )
 
-        # We keep the first average logits as a reference
-        # if min_ave_seq_logproba is None:
-        # min_ave_seq_logproba = ave_seq_logproba
-        # else:
-        # log_string(
-        # f"min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba}"
-        # )
-
         # We update everyone
         for model in models:
             run_tests(model, quizz_machine, deterministic_synthesis=False)
index 6f7492d..8dc23a5 100755 (executable)
@@ -17,43 +17,6 @@ from mygpt import BracketedSequence
 
 ######################################################################
 
-
-class Gang(nn.Module):
-    def __init__(self, models, nb_models_for_generation, mode="groupthink"):
-        super().__init__()
-        self.models = nn.ModuleList(models)
-        self.nb_models_for_generation = nb_models_for_generation
-        self.mode = mode
-
-    def forward(self, bs):
-        # If first = 0, we are re-starting an auto-regressive process,
-        # that's the right moment to randomize who gonna do it
-        if bs.first == 0:
-            self.models_to_use = [
-                self.models[k]
-                for k in torch.randperm(len(self.models))[
-                    : self.nb_models_for_generation
-                ]
-            ]
-
-        all_the_logits = torch.cat(
-            [model(bs).x[None] for model in self.models_to_use], dim=0
-        )
-
-        if self.mode == "groupthink":
-            y = all_the_logits.mean(dim=0)
-        elif self.mode == "groupwork":
-            m = torch.rand(all_the_logits.size(), device=all_the_logits.device)
-            m = (m.sort(dim=0).indices == 0).long()
-            y = (y * m).sum(dim=0)
-        else:
-            raise ValueError(f"Invalid mode {self.mode}")
-
-        return BracketedSequence(y, bs.first, bs.nb)
-
-
-######################################################################
-
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
 # 1s where tokens should be generated. The others are kept
 # unchanged.
@@ -374,9 +337,7 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(
-        self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
-    ):
+    def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -406,7 +367,6 @@ class QuizzMachine:
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=False,
-            # progress_bar_desc="sampling c_quizzes",
             device=self.device,
         )
 
@@ -422,7 +382,6 @@ class QuizzMachine:
             seq_logproba=seq_logproba,
             temperature=temperature,
             deterministic_synthesis=True,
-            # progress_bar_desc="sampling c_quizzes",
             device=self.device,
         )
 
@@ -436,56 +395,7 @@ class QuizzMachine:
                 seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=True,
-                # progress_bar_desc="sampling c_quizzes",
                 device=self.device,
             )
 
         return c_quizzes, seq_logproba.mean()
-
-    ######################################################################
-
-    def create_c_quizzes(
-        self,
-        nb,
-        model_for_generation,
-        models_for_validation,
-        min_ave_seq_logproba,
-        reverse_cleanup,
-        n_epoch,
-        result_dir,
-    ):
-        c_quizzes, ave_seq_logproba = self.generate_quizzes(
-            nb,
-            model_for_generation=model_for_generation,
-            min_ave_seq_logproba=min_ave_seq_logproba,
-            reverse_cleanup=reverse_cleanup,
-        )
-
-        nb_correct = self.comput_correctness(c_quizzes, models_for_validation)
-
-        return c_quizzes, nb_correct, ave_seq_logproba
-
-    ######################################################################
-
-    def gang_create_c_quizzes(
-        self,
-        nb,
-        nb_models_for_generation,
-        models,
-        mode,
-        min_ave_seq_logproba,
-        reverse_cleanup,
-        n_epoch,
-        result_dir,
-    ):
-        model_for_generation = Gang(models, nb_models_for_generation, mode)
-        models_for_validation = models
-        return self.create_c_quizzes(
-            nb=nb,
-            model_for_generation=model_for_generation,
-            models_for_validation=models_for_validation,
-            min_ave_seq_logproba=min_ave_seq_logproba,
-            reverse_cleanup=reverse_cleanup,
-            n_epoch=n_epoch,
-            result_dir=result_dir,
-        )