Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 18:20:30 +0000 (20:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 18:20:30 +0000 (20:20 +0200)
main.py

diff --git a/main.py b/main.py
index 6113813..61fc090 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -68,6 +68,7 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 parser.add_argument("--reboot", action="store_true", default=False)
 
 # ----------------------------------
+
 parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
@@ -83,6 +84,7 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 parser.add_argument("--dropout", type=float, default=0.5)
 
 # ----------------------------------
+
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--problem", type=str, default="grids")
@@ -103,7 +105,7 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
@@ -119,8 +121,6 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
 
-parser.add_argument("--logit_std_max", type=float, default=-1)
-
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -341,275 +341,6 @@ def optimizer_to(optim, device):
                         subparam._grad.data = subparam._grad.data.to(device)
 
 
-######################################################################
-
-
-def run_tests(model, quiz_machine, local_device=main_device):
-    with torch.autograd.no_grad():
-        model.to(local_device).eval()
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
-
-        full_input, _, full_mask_loss = quiz_machine.data_input(
-            args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
-        )
-        src = zip(
-            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
-        )
-
-        for input, mask_loss in tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc="test",
-            total=full_input.size(0) // args.batch_size,
-        ):
-            input = input.to(local_device)
-            mask_loss = mask_loss.to(local_device)
-            targets = input
-
-            output = model(input)
-            loss_per_token = F.cross_entropy(
-                output.transpose(1, 2), targets, reduction="none"
-            )
-            loss = (loss_per_token * mask_loss).mean()
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
-
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
-
-        input, _, _ = quiz_machine.data_input(
-            2000, model.test_c_quiz_bags, args.c_quiz_multiplier
-        )
-
-        model.test_accuracy = quiz_machine.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            input=input,
-            result_dir=args.result_dir,
-        )
-
-
-######################################################################
-
-
-def one_epoch(model, quiz_machine, local_device=main_device):
-    model.to(local_device).train()
-    optimizer_to(model.optimizer, local_device)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    full_input, _, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples,
-        model.train_c_quiz_bags + common_c_quiz_bags,
-        args.c_quiz_multiplier,
-    )
-    src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
-
-    for input, mask_loss in tqdm.tqdm(
-        src,
-        dynamic_ncols=True,
-        desc="training",
-        total=full_input.size(0) // args.batch_size,
-    ):
-        input = input.to(local_device)
-        mask_loss = mask_loss.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.zero_grad()
-
-        targets = input
-        output = model(input)
-        loss = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-        loss = (loss * mask_loss).mean() + model.loss
-
-        acc_train_loss += loss.item() * input.size(0)
-        nb_train_samples += input.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
-
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-    log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
-
-    run_tests(model, quiz_machine)
-
-    model.to(main_device)
-    optimizer_to(model.optimizer, main_device)
-
-
-######################################################################
-
-
-def model_modifier_hot(model):
-    model.temperature = args.temperature_hot
-    # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
-
-
-def model_modifier_cold(model):
-    model.temperature = args.temperature_cold
-    # pass
-
-
-c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
-]
-
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
-    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
-######################################################################
-
-
-def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
-    nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
-    nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
-
-    start_time = time.perf_counter()
-
-    for model in models:
-        model.recorded_c_quizzes = []
-
-    teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
-
-    while nb_validated < nb_to_validate:
-        model_for_generation = models[torch.randint(len(models), (1,)).item()]
-
-        # We generate quizzes with a procedure that injects some
-        # structured noise
-
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            nb_to_generate_per_iteration,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        nb_generated += c_quizzes.size(0)
-
-        # We discard the trivial ones, according to a criterion
-        # specific to the world quizzes (e.g. B=f(B))
-
-        to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-
-        c_quizzes = c_quizzes[to_keep]
-
-        # Compute the responses of all the models on the c_quizzes,
-        # and their proba estimates of their responses
-
-        solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
-
-        proba_own_solution = torch.zeros(
-            c_quizzes.size(0), len(models), device=solved_c_quizzes.device
-        )
-
-        for model in models:
-            (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
-                model,
-                solved_c_quizzes[:, model.id],
-                quad_orders=("A", "f_A", "B", "f_B"),
-                quad=(0, 0, 0, 1),
-            )
-
-            proba_own_solution[:, model.id] = model_proba_solutions(
-                model, solved_c_quizzes[:, model.id]
-            )
-
-        # Now for every model not confident of its response, we pick
-        # the most consistent from a model which is confident
-
-        for s in range(proba_own_solution.size(0)):
-            # At least one GPT does not understand at all
-            if proba_own_solution[s, :].min() < args.proba_not_understands:
-                dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
-                nb_fails = dont_get_this_quiz.long().sum()
-                # At most max_fail_to_validate do not understand (default 3/5)
-                if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
-                    for model in models:
-                        # If a GPT does not get that quiz
-                        if dont_get_this_quiz[model.id]:
-                            assert (
-                                proba_own_solution[s, model.id] < args.proba_understands
-                            )
-                            # Look at its estimate of the others'solutions
-                            proba_other_solutions = model_proba_solutions(
-                                model, solved_c_quizzes[s]
-                            )
-                            # Randomize a bit the orders for the frequent P=1
-                            proba_other_solutions += (
-                                torch.rand(proba_other_solutions.size()) * 1e-6
-                            )
-                            # Remove the under threshold confidence solutions
-                            proba_other_solutions[dont_get_this_quiz] = -1
-                            i = proba_other_solutions.argmax()
-                            model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
-                            teaching_count[i, model.id] += 1
-                            nb_validated += 1
-
-        duration = time.perf_counter() - start_time
-
-        if nb_validated > 0:
-            if nb_validated < nb_to_validate:
-                d = (nb_to_validate - nb_validated) * duration / nb_validated
-                e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
-                    "%a %H:%M"
-                )
-            else:
-                e = "now!"
-        else:
-            e = "???"
-
-        log_string(
-            f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
-        )
-
-    for s in range(teaching_count.size(0)):
-        o = [x.item() for x in teaching_count[s]]
-        log_string(f"teacher model {s} to {o}")
-
-    for model in models:
-        new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
-
-        if new_bag.size(0) > 0:
-            n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
-            if n > 0:
-                model.train_c_quiz_bags.append(new_bag[:n])
-            if n < new_bag.size(0):
-                model.test_c_quiz_bags.append(new_bag[n:])
-
-            c_quizzes = new_bag[:128]
-
-            l = [model_proba_solutions(model, c_quizzes) for model in models]
-            probas = torch.cat([x[:, None] for x in l], dim=1)
-            comments = []
-
-            for l in probas:
-                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
-            filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir, filename, c_quizzes, comments=comments
-            )
-
-        log_string(
-            f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
-        )
-
-
 ######################################################################
 
 from mygpt import (
@@ -959,6 +690,16 @@ class FunctionalAE(nn.Module):
 
 ######################################################################
 
+# quad_order, quad_generate, quad_noise, quad_loss
+
+data_structures = [
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
+
 
 def ae_batches(
     quiz_machine,
@@ -1305,7 +1046,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device)
+    run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################