Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 21:09:44 +0000 (23:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 21:09:44 +0000 (23:09 +0200)
main.py
problem.py

diff --git a/main.py b/main.py
index 7aeae98..50e34a8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -675,6 +675,166 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 ######################################################################
 
 
+def generate_c_quizz_with_generator(generator, quiz_machine, nb):
+    generator.to(main_device)
+
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(
+        nb, struct=("A", "f_A", "B", "f_B")
+    )
+
+    i = F.one_hot(
+        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+        num_classes=args.nb_gpts,
+    )
+
+    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    len_prolog, len_quiz = prolog.size(1), c_quizzes.size(1)
+
+    prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device)
+
+    T = torch.arange(prologued_c_quizzes.size(1), device=prologued_c_quizzes.device)[
+        None, :
+    ]
+
+    ar_mask = ((T >= len_prolog) & ((T - len_prolog) % (len_quiz // 4) > 0)).long()
+
+    seq_logproba = torch.zeros(
+        prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
+    )
+
+    with torch.autograd.no_grad():
+        t = generator.training
+        generator.eval()
+
+        one_batch_masked_inplace_autoregression(
+            generator,
+            prologued_c_quizzes,
+            ar_mask,
+            seq_logproba,
+            deterministic_synthesis=False,
+        )
+
+        generator.train(t)
+
+    prologued_c_quizzes = (
+        prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
+    )
+
+    return prologued_c_quizzes[:, len_prolog:].to("cpu")
+
+
+def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
+    samples = []
+
+    for _ in range(args.nb_train_samples // args.batch_size):
+        while sum([x.size(0) for x in samples]) < args.batch_size:
+            # Generate a bunch of quizzes
+
+            if w_quizzes:
+                # Either we start with the world quizzes
+                c_quizzes = quiz_machine.problem.generate_w_quizzes(
+                    args.batch_size, progress_bar=False
+                )
+            else:
+                # Or we use the generator itself to generate them
+                c_quizzes = generate_c_quizz_with_generator(
+                    args.batch_size, generator, quiz_machine
+                )
+
+            # We remove the trivial ones
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
+
+            # If there are remaining ones, we compute the true prolog
+            # that indicates how the GPTs solve it
+
+            if c_quizzes.size(0) > 0:
+                seq_logproba = quiz_machine.models_logprobas(
+                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+                ) + quiz_machine.models_logprobas(
+                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+                )
+
+                probas = seq_logproba.exp()
+
+                nu = probas <= args.proba_not_understands
+                u = probas >= args.proba_understands
+
+                prolog = (
+                    (nu.long() * token_prolog_0)
+                    + (((nu == False) & (u == False)).long() * token_prolog_1)
+                    + (u.long() * token_prolog_2)
+                )
+
+                prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
+
+                # nb_u = u.long().sum(dim=1)
+                # nb_nu = nu.long().sum(dim=1)
+
+                # prologued_c_quizzes = prologued_c_quizzes[
+                # (nb_u + nb_nu == args.nb_gpts)
+                # & (nb_nu >= 1)
+                # & (nb_nu <= args.max_fail_to_validate)
+                # ]
+
+                samples.append(prologued_c_quizzes)
+
+        # Now we yield a batch
+
+        x = torch.cat(samples, dim=0)
+        samples = [x[args.batch_size :]]
+
+        yield x[: args.batch_size]
+
+
+def one_generator_epoch(
+    generator, quiz_machine, models, w_quizzes=True, local_device=main_device
+):
+    model.to(local_device).train()
+
+    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    hard_w_quizzes = []
+
+    src = batches_for_generator(
+        generator=generator, quiz_machine=quiz_machine, models=models
+    )
+
+    for input in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=args.nb_train_samples // args.batch_size,
+    ):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        targets = input
+
+        output = generator(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), targets)
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
+
+    generator.to(main_device)
+
+
+######################################################################
+
+
 def train_complexifier(model_gen, model_pred1, model_pred2):
     samples = []
     perf = []
@@ -804,170 +964,6 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
-token_prolog_0 = vocabulary_size + 0
-token_prolog_1 = vocabulary_size + 1
-token_prolog_2 = vocabulary_size + 2
-generator_vocabulary_size = vocabulary_size + 3
-
-generator = mygpt.MyGPT(
-    vocabulary_size=generator_vocabulary_size,
-    dim_model=args.dim_model,
-    dim_keys=args.dim_keys,
-    dim_hidden=args.dim_hidden,
-    nb_heads=args.nb_heads,
-    nb_blocks=args.nb_blocks,
-    causal=True,
-    dropout=args.dropout,
-).to(main_device)
-
-generator.main_test_accuracy = 0.0
-
-
-######################################################################
-
-
-def generate_c_quizz_with_generator(generator, quiz_machine):
-    c_quizzes = quiz_machine.problem.create_empty_quizzes(
-        args.batch_size, struct=("A", "f_A", "B", "f_B")
-    )
-    i = F.one_hot(
-        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
-        num_classes=args.nb_gpts,
-    )
-    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
-    c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
-    ar_mask = (
-        torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :]
-        >= args.nb_gpts
-    ).long()
-
-    one_batch_masked_inplace_autoregression(
-        generator,
-        c_quizzes,
-        ar_mask,
-        seq_logproba,
-        deterministic_synthesis=False,
-    )
-
-    return c_quizzes[:, args.nb_gpts :]
-
-
-def batches_for_generator(generator=None, quiz_machine=None, device=main_device):
-    samples = []
-
-    for _ in range(args.nb_train_samples // args.batch_size):
-        while sum([x.size(0) for x in samples]) < args.batch_size:
-            # Generate a bunch of quizzes
-
-            if generator is None:
-                # Either we start with the world quizzes
-                c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size)
-            else:
-                # Or we use the generator itself to generate them
-                c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine)
-
-            # We remove the trivial ones
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            # If there are remaining ones, we compute the true prolog
-            # that indicates how the GPTs solve it
-
-            if c_quizzes.size(0) > 0:
-                seq_logproba = quiz_machine.models_logprobas(
-                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-                ) + quiz_machine.models_logprobas(
-                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
-                )
-
-                probas = seq_logproba.exp()
-
-                nu = probas <= args.proba_not_understands
-                u = probas >= args.proba_understands
-
-                prolog = (
-                    (nu.long() * token_prolog_0)
-                    + (u.long() * token_prolog_2)
-                    + ((nu == False & u == False).long() * token_prolog_1)
-                )
-
-                samples.append(torch.cat([prolog, c_quizzes], dim=1))
-
-        # Now we yield a batch
-
-        x = torch.cat(samples, dim=0)
-        samples = [x[args.batch_size :]]
-
-        yield x[: args.batch_size]
-
-
-def one_generator_epoch(
-    generator, quiz_machine=None, models=None, local_device=main_device
-):
-    model.to(local_device).train()
-
-    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    hard_w_quizzes = []
-
-    full_input, full_from_w = quiz_machine.data_input(generator, split="train")
-    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
-
-    for input, from_w in tqdm.tqdm(
-        src,
-        dynamic_ncols=True,
-        desc="training",
-        total=full_input.size(0) // args.batch_size,
-    ):
-        input = input.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
-
-        targets = input
-
-        output = generator(mygpt.BracketedSequence(input)).x
-        loss_per_token = F.cross_entropy(
-            output.transpose(1, 2), targets, reduction="none"
-        )
-        loss = loss_per_token.mean()
-        acc_train_loss += loss.item() * input.size(0)
-
-        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        if from_w.any():
-            hard_w_quizzes.append(
-                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
-            )
-
-        nb_train_samples += input.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
-
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-    log_string(
-        f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}"
-    )
-
-    run_tests(generator, quiz_machine)
-
-    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
-    threshold = threshold[threshold.size(0) // 2]
-
-    generator.hard_w_quizzes = torch.cat(
-        [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
-    )
-
-    generator.to(main_device)
-
-
-######################################################################
-
 current_epoch = 0
 
 if args.resume:
@@ -1033,6 +1029,59 @@ if args.dirty_debug:
 
 # exit(0)
 
+######################################################################
+
+token_prolog_0 = vocabulary_size + 0
+token_prolog_1 = vocabulary_size + 1
+token_prolog_2 = vocabulary_size + 2
+generator_vocabulary_size = vocabulary_size + 3
+
+generator = mygpt.MyGPT(
+    vocabulary_size=generator_vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+).to(main_device)
+
+generator.main_test_accuracy = 0.0
+
+for n_epoch in range(25):
+    one_generator_epoch(
+        generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        w_quizzes=True,
+        local_device=main_device,
+    )
+
+    c_quizzes = generate_c_quizz_with_generator(
+        generator, quiz_machine, args.batch_size
+    )
+
+    seq_logproba = quiz_machine.models_logprobas(
+        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+    ) + quiz_machine.models_logprobas(
+        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+    )
+
+    print(seq_logproba.exp())
+
+
+one_generator_epoch(
+    generator,
+    quiz_machine=quiz_machine,
+    models=models,
+    w_quizzes=False,
+    local_device=main_device,
+)
+
+exit(0)
+
+
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
index 50376d6..9bee5b2 100755 (executable)
@@ -30,7 +30,7 @@ class Problem:
             quizzes = self.generate_w_quizzes_(self.chunk_size)
             self.queue.put(quizzes.to("cpu"), block=True)
 
-    def generate_w_quizzes(self, nb):
+    def generate_w_quizzes(self, nb, progress_bar=True):
         if self.queue is None:
             return self.generate_w_quizzes_(nb)
 
@@ -43,16 +43,22 @@ class Problem:
 
         n = sum([q.size(0) for q in quizzes])
 
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
             while n < nb:
                 q = self.queue.get(block=True)
                 quizzes.append(q)
                 n += q.size(0)
-                pbar.update(q.size(0))
 
         quizzes = torch.cat(quizzes, dim=0)
         assert n == quizzes.size(0)