Update.
author    François Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 20:21:03 +0000 (22:21 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 20:21:03 +0000 (22:21 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 41efc86..ca1e9b5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -314,7 +314,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        for input in quiz_machine.batches(model, split="test"):
+        full_input, _ = quiz_machine.data_input(model, split="test")
+        src = full_input.split(args.batch_size)
+
+        for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
             input = input.to(local_device)
 
             bs = model(mygpt.BracketedSequence(input))
@@ -345,16 +348,29 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in quiz_machine.batches(model, split="train"):
+    hard_w_quizzes = []
+
+    full_input, full_from_w = quiz_machine.data_input(model, split="train")
+    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+
+    for input, from_w in tqdm.tqdm(src, dynamic_ncols=True, desc="training"):
         input = input.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
 
         output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
+        loss_per_token = F.cross_entropy(
+            output.transpose(1, 2), input, reduction="none"
+        )
+        loss = loss_per_token.mean()
         acc_train_loss += loss.item() * input.size(0)
 
+        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
+        hard_w_quizzes.append(
+            (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
+        )
+
         nb_train_samples += input.size(0)
 
         loss.backward()
@@ -368,6 +384,13 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     run_tests(model, quiz_machine, deterministic_synthesis=False)
 
+    threshold = torch.cat([x[1] for x in hard_w_quizzes], dim=0).sort().values
+    threshold = threshold[threshold.size(0) // 2]
+
+    model.hard_w_quizzes = torch.cat(
+        [x[0][x[1] >= threshold] for x in hard_w_quizzes], dim=0
+    )
+
     model.to(main_device)
 
 
@@ -443,7 +466,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
     validated_quizzes = torch.cat(recorded, dim=0)
@@ -542,54 +565,6 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-# Compute the entropy of the training tokens
-
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
-        (0, 1)
-    )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
-
-######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
-    def subsets_as_tuples(batches, cs):
-        s = set()
-        for batch in batches:
-            for x in batch:
-                s.add(tuple([v.item() for v in x]))
-                if len(s) == cs:
-                    yield s
-                    s = set()
-        yield s
-
-    nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(
-        quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
-    ):
-        in_train = set()
-        for train_subset in subsets_as_tuples(
-            quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
-        ):
-            in_train.update(test_subset.intersection(train_subset))
-        nb_in_train += len(in_train)
-        nb_test += len(test_subset)
-
-    log_string(
-        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
-    )
-
-    assert (
-        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
-    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
-
-######################################################################
-
 if args.nb_new_c_quizzes_for_train is None:
     args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
 
@@ -679,7 +654,6 @@ for n_epoch in range(args.nb_epochs):
     for model in weakest_models:
         quiz_machine.renew_w_quizzes(
             model=model,
-            nb=args.nb_train_samples,
             for_train=True,
             forward_only=args.forward_only,
         )
diff --git a/quiz_machine.py b/quiz_machine.py
index faa640e..32b3f7e 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -313,7 +313,7 @@ class QuizMachine:
 
     ######################################################################
 
-    def batches(self, model, split="train", desc=None):
+    def data_input(self, model, split="train"):
         assert split in {"train", "test"}
 
         with self.LOCK_C_QUIZZES:
@@ -335,24 +335,18 @@ class QuizMachine:
                 ]
                 w_quizzes = w_quizzes[i]
 
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = c_quizzes.size(0)
+                quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+                from_w = torch.arange(
+                    quizzes.size(0), device=quizzes.device
+                ) < w_quizzes.size(0)
+                i = torch.randperm(quizzes.size(0), device=quizzes.device)
 
-                input = torch.cat([w_quizzes, c_quizzes], dim=0)
-            else:
-                input = w_quizzes
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = 0
-
-        # Shuffle
-        input = input[torch.randperm(input.size(0))]
+                return quizzes[i], from_w[i]
 
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
+            else:
+                return w_quizzes, torch.full(
+                    (w_quizzes.size(0),), True, device=w_quizzes.device
+                )
 
     ######################################################################
 
@@ -441,14 +435,29 @@ class QuizMachine:
 
     ######################################################################
 
-    def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+    def renew_w_quizzes(self, model, for_train=True, forward_only=False):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
-        nb = min(nb, input.size(0))
-        input[:-nb] = input[nb:].clone()
-        fresh_w_quizzes = self.generate_token_sequences(nb)
-        if not forward_only:
-            self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to("cpu")
+
+        if for_train and hasattr(model, "hard_w_quizzes"):
+            self.logger(
+                f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+            )
+            if model.hard_w_quizzes.size(0) >= input.size(0):
+                input[...] = model.hard_w_quizzes[
+                    torch.randperm(model.hard_w_quizzes.size(0))[: input.size(0)]
+                ]
+            else:
+                input[...] = torch.cat(
+                    [
+                        model.hard_w_quizzes,
+                        self.generate_token_sequences(
+                            input.size(0) - model.hard_w_quizzes.size(0)
+                        ),
+                    ],
+                    dim=0,
+                )
+        else:
+            input[...] = self.generate_token_sequences(input.size(0))
 
     ######################################################################
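
Note on how the new data_input() is consumed: main.py now splits the returned quiz tensor and from_w mask into chunks of args.batch_size and iterates over them together (see the run_tests and one_epoch hunks above). Below is a minimal sketch of that pattern, not the repository code; batch_size and the tensor shapes are illustrative dummies.

import torch

batch_size = 25                                  # stands in for args.batch_size
full_input = torch.randint(0, 100, (1000, 64))   # all quizzes of the split
full_from_w = torch.rand(1000) < 0.5             # True where the quiz is a world quiz

src = zip(full_input.split(batch_size), full_from_w.split(batch_size))

for input, from_w in src:
    # input: (batch_size, seq_len) token sequences, from_w: (batch_size,) boolean mask
    w_batch = input[from_w]  # world quizzes of this batch, as used for the hard-quiz bookkeeping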