Add a --reboot option that pools the culture quizzes of the best model into common_c_quiz_bags and re-creates the models from scratch; rename best_* to gen_*; remove save_additional_results.
author     François Fleuret <francois@fleuret.org>
           Wed, 21 Aug 2024 15:08:59 +0000 (17:08 +0200)
committer  François Fleuret <francois@fleuret.org>
           Wed, 21 Aug 2024 15:08:59 +0000 (17:08 +0200)
main.py

diff --git a/main.py b/main.py
index d98031e..65695af 100755
--- a/main.py
+++ b/main.py
@@ -65,7 +65,7 @@ parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
-parser.add_argument("--lambda_H", type=float, default=0.0)
+parser.add_argument("--reboot", action="store_true", default=False)
 
 parser.add_argument("--schedule_free", action="store_true", default=False)
 
@@ -395,7 +395,9 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
+        args.nb_train_samples,
+        model.train_c_quiz_bags + common_c_quiz_bags,
+        args.c_quiz_multiplier,
     )
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
@@ -472,76 +474,6 @@ c_quizzes_procedure = [
 ######################################################################
 
 
-def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
-    # Save generated quizzes with the successive generation steps
-
-    recorder = []
-
-    c_quizzes = quiz_machine.generate_c_quizzes(
-        64,
-        model_for_generation=model,
-        procedure=c_quizzes_procedure,
-        recorder=recorder,
-    )
-
-    # This is nb_quizzes x nb_models
-
-    l = [
-        quiz_machine.models_logprobas(
-            model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        for model in models
-    ]
-
-    seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
-    probas = seq_logprobas.exp()
-
-    comments = []
-
-    for l in seq_logprobas:
-        comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
-
-    ##
-
-    c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
-    predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
-    nb_steps = c_quizzes.size(1)
-    c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
-    predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
-
-    # We have comments only for the final quiz, not the successive
-    # steps, so we have to add nb_steps-1 empty comments
-
-    steps_comments = []
-    for c in comments:
-        steps_comments += [""] * (nb_steps - 1) + [c]
-
-    filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes,
-        predicted_parts=predicted_parts,
-        comments=steps_comments,
-        nrow=nb_steps * 2,  # two quiz per row
-    )
-
-    log_string(f"wrote {filename}")
-
-
-######################################################################
-
-
 def model_proba_solutions(model, quizzes):
     l = (
         quiz_machine.models_logprobas(
@@ -561,6 +493,9 @@ def model_proba_solutions(model, quizzes):
     return l.exp()
 
 
+######################################################################
+
+
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
@@ -804,60 +739,69 @@ class Thinker(nn.Module):
 
 ######################################################################
 
-models = []
-
-
-def compute_causal_attzero(t_q, t_k):
-    return t_q < t_k
 
+def create_models():
+    models = []
 
-if args.schedule_free:
-    import schedulefree
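+    # attention is masked wherever the key position t_k comes after the query position t_q (causal mask)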
+    def compute_causal_attzero(t_q, t_k):
+        return t_q < t_k
 
-for k in range(args.nb_gpts):
-    log_string(f"creating model {k}")
+    if args.schedule_free:
+        import schedulefree
+
+    for k in range(args.nb_gpts):
+        log_string(f"creating model {k}")
+
+        model = mygpt.MyGPT(
+            vocabulary_size=vocabulary_size,
+            dim_model=args.dim_model,
+            dim_keys=args.dim_keys,
+            dim_hidden=args.dim_hidden,
+            nb_heads=args.nb_heads,
+            nb_blocks=args.nb_blocks,
+            compute_attzero=compute_causal_attzero,
+            dropout=args.dropout,
+        ).to(main_device)
+
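+        # standardize along the last dimension, clamping the normalizing std at std_max
+        # (wrapped around the readout below to limit the scale of the logits)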
+        class UpperBoundStd(nn.Module):
+            def __init__(self, std_max=1.0):
+                super().__init__()
+                self.std_max = std_max
+
+            def forward(self, x):
+                std = x.std(dim=-1, keepdim=True)
+                y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
+                return y
+
+        if args.logit_std_max > 0:
+            model.readout.f = nn.Sequential(
+                model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+            )
 
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        compute_attzero=compute_causal_attzero,
-        dropout=args.dropout,
-    ).to(main_device)
+        model.id = k
+        model.train_c_quiz_bags = []
+        model.test_c_quiz_bags = []
 
-    class UpperBoundStd(nn.Module):
-        def __init__(self, std_max=1.0):
-            super().__init__()
-            self.std_max = std_max
+        if args.schedule_free:
+            model.optimizer = schedulefree.AdamWScheduleFree(
+                model.parameters(), lr=args.learning_rate
+            )
+        else:
+            model.optimizer = torch.optim.Adam(
+                model.parameters(), lr=args.learning_rate
+            )
 
-        def forward(self, x):
-            std = x.std(dim=-1, keepdim=True)
-            y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
-            return y
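+        # gen_* keep the snapshot of the parameters (and its test accuracy) used to generate culture quizzes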
+        model.test_accuracy = 0.0
+        model.gen_test_accuracy = 0.0
+        model.gen_state_dict = copy.deepcopy(model.state_dict())
+        models.append(model)
 
-    if args.logit_std_max > 0:
-        model.readout.f = nn.Sequential(
-            model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
-        )
+    return models
 
-    model.id = k
-    model.train_c_quiz_bags = []
-    model.test_c_quiz_bags = []
 
-    if args.schedule_free:
-        model.optimizer = schedulefree.AdamWScheduleFree(
-            model.parameters(), lr=args.learning_rate
-        )
-    else:
-        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
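+# culture quizzes shared by all the models, accumulated at each reboot and added to their training data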
+common_c_quiz_bags = []
 
-    model.test_accuracy = 0.0
-    model.best_test_accuracy = 0.0
-    model.best_dict = copy.deepcopy(model.state_dict())
-    models.append(model)
+models = create_models()
 
 ######################################################################
 
@@ -878,8 +822,8 @@ if args.resume:
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.test_accuracy = d["test_accuracy"]
-            model.best_test_accuracy = d["best_test_accuracy"]
-            model.best_dict = d["best_dict"]
+            model.gen_test_accuracy = d["gen_test_accuracy"]
+            model.gen_state_dict = d["gen_state_dict"]
             model.train_c_quiz_bags = d["train_c_quiz_bags"]
             model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded {filename}")
@@ -906,10 +850,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
@@ -1128,41 +1072,64 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
-    cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models])
-    log_string(f"current_best_test_accuracies {cta}")
+    cta = " ".join([f"{float(m.gen_test_accuracy):.04f}" for m in models])
+    log_string(f"current_gen_test_accuracies {cta}")
 
     ##################################################
 
     for model in models:
         if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
             log_string(
-                f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}"
+                f"storing_gen model {model.id} accuracy {model.gen_test_accuracy} -> {model.test_accuracy}"
             )
-            model.best_dict = copy.deepcopy(model.state_dict())
-            model.best_test_accuracy = model.test_accuracy
+            model.gen_state_dict = copy.deepcopy(model.state_dict())
+            model.gen_test_accuracy = model.test_accuracy
 
     # we restart
     if total_time_generating_c_quizzes == 0:
         total_time_training_models = 0
 
     if (
-        min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
+        min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
         and total_time_training_models >= total_time_generating_c_quizzes
     ):
+        ######################################################################
+        # Re-initialize the models from scratch once one of them has accumulated enough culture quizzes
+
+        if args.reboot:
+            nb_c_quizzes_per_model = [
+                sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models
+            ]
+
+            m = max(nb_c_quizzes_per_model)
+
+            if m >= args.nb_train_samples:
+                model = models[nb_c_quizzes_per_model.index(m)]
+                common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
+                nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
+                log_string(
+                    f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
+                )
+
+                models = create_models()
+                total_time_generating_c_quizzes = 0
+                total_time_training_models = 0
+
         for model in models:
             model.current_dict = copy.deepcopy(model.state_dict())
-            model.load_state_dict(model.best_dict)
+            model.load_state_dict(model.gen_state_dict)
 
         start_time = time.perf_counter()
+
         record_new_c_quizzes(
             models,
             quiz_machine,
             args.nb_new_c_quizzes_for_train,
             args.nb_new_c_quizzes_for_test,
         )
+
         total_time_generating_c_quizzes += time.perf_counter() - start_time
 
-        # Force one epoch of training
         for model in models:
             model.load_state_dict(model.current_dict)
 
@@ -1204,9 +1171,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         total_time_training_models += time.perf_counter() - start_time
 
-        for model in weakest_models:
-            save_additional_results(n_epoch, model, models, c_quizzes_procedure)
-
     # Save the models to disk
 
     for model in models:
@@ -1216,8 +1180,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
                 "test_accuracy": model.test_accuracy,
-                "best_test_accuracy": model.best_test_accuracy,
-                "best_dict": model.best_dict,
+                "gen_test_accuracy": model.gen_test_accuracy,
+                "gen_state_dict": model.gen_state_dict,
                 "train_c_quiz_bags": model.train_c_quiz_bags,
                 "test_c_quiz_bags": model.test_c_quiz_bags,
             },