Update.
author François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 06:57:27 +0000 (08:57 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 06:57:27 +0000 (08:57 +0200)
main.py

diff --git a/main.py b/main.py
index 6cbb2c4..4488a70 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -383,7 +383,7 @@ data_structures = [
 
 def masked_cross_entropy(output, targets, masks):
     loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
+    return (loss_per_token * masks).mean()
 
 
 ######################################################################
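For reference, the two normalizations above are not equivalent in general: the removed expression divides the masked loss by the number of masked tokens, while `.mean()` divides by the total token count, so they coincide exactly when the mask is all ones. A minimal standalone sketch (shapes assumed here: output is batch x seq x vocab, targets and masks are batch x seq):

import torch
import torch.nn.functional as F

output = torch.randn(2, 5, 10)       # batch x seq x vocab logits
targets = torch.randint(10, (2, 5))  # batch x seq token ids
masks = torch.ones(2, 5)             # all-ones mask: the two forms agree

loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
masked_mean = (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
plain_mean = (loss_per_token * masks).mean()
assert torch.isclose(masked_mean, plain_mean)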
@@ -492,6 +492,8 @@ def prioritized_rand(low):
 
 
 def generate(model, nb, local_device=main_device):
+    model.eval().to(local_device)
+
     all_input = quiz_machine.pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
 
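The added `model.eval().to(local_device)` matters because the models are built with dropout (see `dropout=args.dropout` below), which must be disabled when sampling; `eval()` returns the module itself, so the chained `.to()` works. A quick sketch of the effect on a toy module:

import torch
import torch.nn as nn

m = nn.Dropout(p=0.5)
x = torch.ones(4)
m.train()
print(m(x))  # entries randomly zeroed, survivors rescaled to 2.0
m.eval()
print(m(x))  # identity: tensor([1., 1., 1., 1.])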
@@ -622,7 +624,12 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         correct_parts=correct_parts[:128],
     )
 
-    model.test_accuracy = correct.sum() / quizzes.size(0)
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+    model.test_accuracy = nb_correct / nb_total
+
+    log_string(
+        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)"
+    )
 
     # generate
 
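A quick illustration of the logging above with hypothetical values; `.item()` turns the 0-dim tensor returned by `sum()` into a plain Python int so the f-string prints cleanly, and the accuracy is scaled by 100 to match the percent sign:

import torch

correct = torch.tensor([1, 0, 1, 1])  # hypothetical per-quiz results
nb_correct, nb_total = correct.sum().item(), correct.size(0)
test_accuracy = nb_correct / nb_total
print(f"nb_correct {nb_correct} / {nb_total} ({test_accuracy*100:.02f}%)")
# nb_correct 3 / 4 (75.00%)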
@@ -634,6 +641,22 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     )
 
 
+######################################################################
+
+
+class TokenCat(nn.Module):
+    def __init__(self, m, n):
+        super().__init__()
+        self.m = m
+        self.n = n
+
+    def forward(self, x):
+        u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1)
+        u = self.m(u)
+        u = u[:, self.n :]
+        return u
+
+
 ######################################################################
 
 import attae
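TokenCat above gives the wrapped model n extra leading positions to compute with: it prepends n zero tokens, runs the model, then strips those positions from the output. A minimal usage sketch with a stand-in model (the embedding-plus-linear toy below is purely illustrative; the actual models come from attae):

import torch
import torch.nn as nn

class ToyLM(nn.Module):
    # maps (batch, seq) token ids to (batch, seq, vocab) logits
    def __init__(self, vocab=16, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, x):
        return self.head(self.embed(x))

model = TokenCat(ToyLM(), n=10)
x = torch.randint(16, (2, 5))
print(model(x).shape)  # torch.Size([2, 5, 16]): the 10 prepended positions are stripped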
@@ -651,6 +674,9 @@ for i in range(args.nb_models):
         dropout=args.dropout,
     ).to(main_device)
 
+    # if i < args.nb_models // 2:
+    #     model = TokenCat(model, 10)
+
     # model = torch.compile(model)
 
     model.id = i
@@ -748,7 +774,7 @@ def quiz_validation_(
 ######################################################################
 
 
-def generate_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes_(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
 
     def copy_for_inference(model):
@@ -842,6 +868,41 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 ######################################################################
 
 
+def generate_c_quizzes(models, nb, local_device=main_device):
+    record = []
+    nb_validated = 0
+    while nb_validated < nb:
+        model = models[torch.randint(len(models), (1,)).item()]
+        model = copy.deepcopy(model).to(local_device).eval()
+        generator_id = model.id
+
+        c_quizzes = generate(
+            model=model,
+            nb=args.physical_batch_size,
+            local_device=local_device,
+        )
+
+        nb_correct, nb_wrong = 0, 0
+        for m in models:
+            m = copy.deepcopy(m).to(local_device).eval()
+            result = predict_full(m, c_quizzes, local_device=local_device)
+            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+            nb_correct += (nb_mistakes == 0).long()
+            nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+        to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+            nb_wrong >= args.nb_have_to_be_wrong
+        )
+
+        nb_validated += to_keep.long().sum().item()
+        record.append(c_quizzes[to_keep])
+
+    return torch.cat(record)
+
+
+######################################################################
+
+
 def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False):
     l = []
 
@@ -1094,17 +1155,15 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
-    # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
-
     multithread_execution(
         one_complete_epoch,
         [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
     )
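For context, this dispatch retrains only the weakest models each epoch, one per available GPU. A tiny sketch of the selection with stand-in objects (hypothetical accuracies and devices):

from types import SimpleNamespace

models = [SimpleNamespace(id=i, test_accuracy=a) for i, a in enumerate([0.9, 0.2, 0.7, 0.5])]
gpus = ["cuda:0", "cuda:1"]

ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
print([(m.id, gpu) for m, gpu in zip(weakest_models, gpus)])  # [(1, 'cuda:0'), (3, 'cuda:1')]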
 
-    # --------------------------------------------------------------------
-
     save_models(models)
 
+    # --------------------------------------------------------------------
+
     duration = time.perf_counter() - start_time
     str_duration = ""
     if duration >= 60: