Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 27 Aug 2024 12:21:09 +0000 (14:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 27 Aug 2024 12:21:09 +0000 (14:21 +0200)
main.py

diff --git a/main.py b/main.py
index eb0f776..2e8ec43 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -93,7 +93,7 @@ parser.add_argument("--gpus", type=str, default="all")
 
 # ----------------------------------
 
-parser.add_argument("--nb_gpts", type=int, default=5)
+parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -464,6 +464,16 @@ c_quizzes_procedure = [
     # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
+# quad_order, quad_generate, quad_noise, quad_loss
+
+data_structures = [
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
+
 ######################################################################
 
 
@@ -783,6 +793,18 @@ class MyAttentionAE(nn.Module):
         trunk_blocks = []
 
         for b in range(nb_blocks):
+            # if b == nb_blocks//2:
+            # trunk_blocks += [
+            # QKVAttention(
+            # dim_in=dim_model,
+            # dim_qk=dim_keys,
+            # dim_v=dim_model // nb_heads,
+            # nb_heads=nb_heads,
+            # attention_dropout=dropout,
+            # ),
+            # VaswaniPositionalEncoding(len_max=1e5)
+            # ]
+
             trunk_blocks += [
                 WithResidual(
                     CacheWrapper(
@@ -864,20 +886,6 @@ def ae_batches(
             mask_loss.to(local_device),
         )
 
-    # quiz_machine.problem.save_quizzes_as_image(
-    # args.result_dir,
-    # filename="a.png",
-    # quizzes=a,
-    # )
-
-    # quiz_machine.problem.save_quizzes_as_image(
-    # args.result_dir,
-    # filename="b.png",
-    # quizzes=b,
-    # )
-
-    # time.sleep(1000)
-
 
 def NTC_masked_cross_entropy(output, targets, mask):
     loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
@@ -892,6 +900,12 @@ def deterministic(mask_generate):
     return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
+# This function returns a tensor of same shape as low, full of uniform
+# random values in [0,1], such that the values corresponding to the
+# True in low are all lesser than the values corresponding to the
+# False.
+
+
 def prioritized_rand(low):
     x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
     k = torch.rand(low.size(), device=low.device) + low.long()
@@ -901,17 +915,15 @@ def prioritized_rand(low):
     return y
 
 
-def ae_generate(
-    model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
-):
+def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
-    input = (1 - mask_generate) * input + mask_generate * noise
 
-    proba_erased = noise_proba
+    input = (1 - mask_generate) * input + mask_generate * noise
 
     d = deterministic(mask_generate)[:, None]
+
     changed = True
 
     for it in range(nb_iterations_max):
@@ -922,7 +934,8 @@ def ae_generate(
 
         r = prioritized_rand(final != input)
 
-        mask_erased = mask_generate * (r <= proba_erased).long()
+        mask_erased = mask_generate * (r <= noise_proba).long()
+
         mask_to_change = d * mask_generate + (1 - d) * mask_erased
 
         update = (1 - mask_to_change) * input + mask_to_change * final
@@ -956,56 +969,22 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     return result
 
 
-def test_ae(local_device=main_device):
-    model = MyAttentionAE(
-        vocabulary_size=vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        dropout=args.dropout,
-    ).to(main_device)
-
-    # quad_order, quad_generate, quad_noise, quad_loss
-
-    data_structures = [
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
-        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
-        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
-        (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-    ]
-
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
-    model.to(local_device).train()
-    optimizer_to(model.optimizer, local_device)
-
-    nb_iterations = 25
-    probs_iterations = torch.arange(nb_iterations, device=main_device)
-    probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+######################################################################
 
-    for n_epoch in range(args.nb_epochs):
-        # ----------------------
-        # Train
 
-        model.train()
-        nb_train_samples, acc_train_loss = 0, 0.0
+def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
+    with torch.autograd.no_grad():
+        model.eval().to(local_device)
 
-        noise_proba = 0.05
+        nb_test_samples, acc_test_loss = 0, 0.0
 
         for input, mask_generate, mask_loss in ae_batches(
             quiz_machine,
-            args.nb_train_samples,
+            args.nb_test_samples,
             data_structures,
             local_device,
-            "training",
+            "test",
         ):
-            if nb_train_samples % args.batch_size == 0:
-                model.optimizer.zero_grad()
-
             d = deterministic(mask_generate)
             p = probs_iterations.expand(input.size(0), -1)
             dist = torch.distributions.categorical.Categorical(probs=p)
@@ -1013,119 +992,168 @@ def test_ae(local_device=main_device):
             N1 = N0 + 1
             N0 = (1 - d) * N0
             N1 = (1 - d) * N1 + d * nb_iterations
-
             targets, input = degrade_input(
                 input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
             )
-
-            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-            # for n in ["input", "targets"]:
-            # filename = f"{n}.png"
-            # quiz_machine.problem.save_quizzes_as_image(
-            # args.result_dir,
-            # filename,
-            # quizzes=locals()[n],
-            # )
-            # log_string(f"wrote {filename}")
-            # time.sleep(1000)
-            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+            acc_test_loss += loss.item() * input.size(0)
+            nb_test_samples += input.size(0)
 
-            loss.backward()
+        log_string(
+            f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
+        )
 
-            if nb_train_samples % args.batch_size == 0:
-                model.optimizer.step()
+        # -------------------------------------------
+        # Test generation
 
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        nb_correct, nb_total, record = 0, 0, []
 
-        log_string(f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples}")
+        for input, mask_generate, mask_loss in ae_batches(
+            quiz_machine,
+            args.nb_test_samples,
+            data_structures,
+            local_device,
+            "test",
+        ):
+            targets = input.clone()
+            result = ae_generate(
+                model, (1 - mask_generate) * input, mask_generate, noise_proba
+            )
+            correct = (result == targets).min(dim=1).values.long()
+            predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
+                :, :, 1
+            ]
+            solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
+            correct = (2 * correct - 1) * (solution_is_deterministic).long()
+            nb_correct += (correct == 1).long().sum()
+            nb_total += (correct != 0).long().sum()
+            correct_parts = predicted_parts * correct[:, None]
+            record.append((result, predicted_parts, correct_parts))
 
-        # ----------------------
-        # Test
+        log_string(
+            f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+        )
 
-        with torch.autograd.no_grad():
-            model.eval()
+        model.test_accuracy = nb_correct / nb_total
 
-            nb_test_samples, acc_test_loss = 0, 0.0
+        filename = f"prediction_ae_{n_epoch:04d}.png"
 
-            for input, mask_generate, mask_loss in ae_batches(
-                quiz_machine,
-                args.nb_test_samples,
-                data_structures,
-                local_device,
-                "test",
-            ):
-                d = deterministic(mask_generate)
-                p = probs_iterations.expand(input.size(0), -1)
-                dist = torch.distributions.categorical.Categorical(probs=p)
-                N0 = dist.sample()
-                N1 = N0 + 1
-                N0 = (1 - d) * N0
-                N1 = (1 - d) * N1 + d * nb_iterations
-                targets, input = degrade_input(
-                    input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
-                )
-                input_with_mask = NTC_channel_cat(input, mask_generate)
-                logits = model(input_with_mask)
-                loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-                acc_test_loss += loss.item() * input.size(0)
-                nb_test_samples += input.size(0)
+        result, predicted_parts, correct_parts = (
+            torch.cat([x[i] for x in record]) for i in [0, 1, 2]
+        )
+
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=result,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
+        )
 
-            log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
+        log_string(f"wrote {filename}")
 
-            # -------------------------------------------
-            # Test generation
 
-            for ns, s in enumerate(data_structures):
-                quad_order, quad_generate, _, _ = s
+######################################################################
 
-                input, mask_generate, _ = next(
-                    ae_batches(quiz_machine, 128, [s], local_device, batch_size=128)
-                )
 
-                targets = input.clone()
-                input = ae_generate(
-                    model,
-                    input,
-                    mask_generate,
-                    n_epoch,
-                    noise_proba=noise_proba,
-                )
+def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device):
+    model.train().to(local_device)
 
-                correct = (input == targets).min(dim=1).values.long()
-                predicted_parts = torch.tensor(quad_generate, device=input.device)
-                predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
-                solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
-                correct = (2 * correct - 1) * (solution_is_deterministic).long()
-                nb_correct = (correct == 1).long().sum()
-                nb_total = (correct != 0).long().sum()
-                correct_parts = predicted_parts * correct[:, None]
-
-                log_string(
-                    f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
-                )
+    nb_train_samples, acc_train_loss = 0, 0.0
 
-                filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
+    for input, mask_generate, mask_loss in ae_batches(
+        quiz_machine,
+        args.nb_train_samples,
+        data_structures,
+        local_device,
+        "training",
+    ):
+        if nb_train_samples % args.batch_size == 0:
+            model.optimizer.zero_grad()
 
-                quiz_machine.problem.save_quizzes_as_image(
-                    args.result_dir,
-                    filename,
-                    quizzes=input,
-                    predicted_parts=predicted_parts,
-                    correct_parts=correct_parts,
-                )
+        d = deterministic(mask_generate)
+        p = probs_iterations.expand(input.size(0), -1)
+        dist = torch.distributions.categorical.Categorical(probs=p)
+        N0 = dist.sample()
+        N1 = N0 + 1
+        N0 = (1 - d) * N0
+        N1 = (1 - d) * N1 + d * nb_iterations
 
-                log_string(f"wrote {filename}")
+        targets, input = degrade_input(
+            input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+        )
 
+        input_with_mask = NTC_channel_cat(input, mask_generate)
+        logits = model(input_with_mask)
+        loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
 
-if args.test == "ae":
-    test_ae(local_device=main_device)
-    exit(0)
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            model.optimizer.step()
+
+    log_string(
+        f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
+    )
+
+    run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
+
+
+######################################################################
+
+noise_proba = 0.05
+
+nb_iterations = 25
+probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
+probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+
+models = []
+
+for i in range(args.nb_models):
+    model = MyAttentionAE(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+    ).to(main_device)
+
+    model.id = i
+    model.test_accuracy = 0.0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.to(main_device).train()
+    optimizer_to(model.optimizer, main_device)
+
+    models.append(model)
+
+for n_epoch in range(args.nb_epochs):
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+    weakest_models = ranked_models[: len(gpus)]
+
+    threads = []
+
+    start_time = time.perf_counter()
+
+    for gpu, model in zip(gpus, weakest_models):
+        log_string(f"training model {model.id}")
+
+        t = threading.Thread(
+            target=one_ae_epoch, daemon=True, args=(model, quiz_machine, n_epoch, gpu)
+        )
+
+        threads.append(t)
+
+        t.start()
+
+        for t in threads:
+            t.join()
 
 ######################################################################
 
@@ -1136,7 +1164,7 @@ def create_models():
     def compute_causal_attzero(t_q, t_k):
         return t_q < t_k
 
-    for k in range(args.nb_gpts):
+    for k in range(args.nb_models):
         log_string(f"creating model {k}")
 
         model = mygpt.MyGPT(
@@ -1244,7 +1272,7 @@ log_string(
 
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
-    args.nb_gpts = 2
+    args.nb_models = 2
     args.nb_new_c_quizzes_for_train = 100
     args.nb_new_c_quizzes_for_test = 10