Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 19:34:17 +0000 (21:34 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 19:34:17 +0000 (21:34 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 19a3c29..7aeae98 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -19,6 +19,8 @@ import ffutils
 import mygpt
 import sky, grids, quiz_machine
 
+from quiz_machine import one_batch_masked_inplace_autoregression
+
 import threading, subprocess
 
 import torch.multiprocessing as mp
@@ -773,7 +775,11 @@ def train_complexifier(model_gen, model_pred1, model_pred2):
 ######################################################################
 
 
-def train_autoencoder():
+models = []
+
+for k in range(args.nb_gpts):
+    log_string(f"creating model {k} and its w_quizzes")
+
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -781,130 +787,184 @@ def train_autoencoder():
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=False,
+        causal=True,
         dropout=args.dropout,
-        autoencoder_dim=args.autoencoder_dim,
     ).to(main_device)
 
-    test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+    model.main_test_accuracy = 0.0
+    model.id = k
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+        args.nb_train_samples
+    )
 
-    nb_train_samples, acc_train_loss = 0, 0.0
+    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
-    for n_epoch in range(args.nb_epochs):
-        train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-        for input in tqdm.tqdm(
-            train_w_quizzes.split(args.batch_size),
-            dynamic_ncols=True,
-            desc="training AE",
-            total=train_w_quizzes.size(0) // args.batch_size,
-        ):
-            model.train()
-            l = input.size(1) // 4
-            input = input[:, -l:].to(main_device)
+    models.append(model)
 
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.zero_grad()
+######################################################################
 
-            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
-            output = model.decode(z_shape).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_train_loss += loss.item() * input.size(0)
+token_prolog_0 = vocabulary_size + 0
+token_prolog_1 = vocabulary_size + 1
+token_prolog_2 = vocabulary_size + 2
+generator_vocabulary_size = vocabulary_size + 3
 
-            nb_train_samples += input.size(0)
+generator = mygpt.MyGPT(
+    vocabulary_size=generator_vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+).to(main_device)
 
-            loss.backward()
+generator.main_test_accuracy = 0.0
 
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.step()
 
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+######################################################################
 
-        log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
 
-        filename = f"autoencoder.pth"
-        torch.save(
-            model.state_dict(),
-            os.path.join(args.result_dir, filename),
-        )
-        log_string(f"wrote {filename}")
+def generate_c_quizz_with_generator(generator, quiz_machine):
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(
+        args.batch_size, struct=("A", "f_A", "B", "f_B")
+    )
+    i = F.one_hot(
+        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+        num_classes=args.nb_gpts,
+    )
+    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
+    ar_mask = (
+        torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :]
+        >= args.nb_gpts
+    ).long()
+
+    one_batch_masked_inplace_autoregression(
+        generator,
+        c_quizzes,
+        ar_mask,
+        seq_logproba,
+        deterministic_synthesis=False,
+    )
 
-        with torch.autograd.no_grad():
-            model.eval()
-            input = test_w_quizzes[0 * 128 : 1 * 128, -l:]
-            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
-            logits = model.decode(z_shape).x
+    return c_quizzes[:, args.nb_gpts :]
 
-            # dist = torch.distributions.categorical.Categorical(logits=logits)
-            # q = dist.sample()
 
-            q = logits.argmax(dim=-1)
-            q = q.reshape(q.size(0) // 2, 2, -1)
-            input = input.reshape(input.size(0) // 2, 2, -1)
-            q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                f"culture_ae_{n_epoch:04d}.png",
-                q,
-            )
+def batches_for_generator(generator=None, quiz_machine=None, device=main_device):
+    samples = []
 
-            input1 = test_w_quizzes[1 * 128 : 2 * 128, -l:]
-            input2 = test_w_quizzes[2 * 128 : 3 * 128, -l:]
-            z_shape1 = model.encode(mygpt.BracketedSequence(input1.to(main_device)))
-            z_shape2 = model.encode(mygpt.BracketedSequence(input2.to(main_device)))
-            z_shape = ((z_shape1[0] + z_shape2[0]) * 0.5, z_shape1[1])
-            logits = model.decode(z_shape).x
+    for _ in range(args.nb_train_samples // args.batch_size):
+        while sum([x.size(0) for x in samples]) < args.batch_size:
+            # Generate a bunch of quizzes
 
-            q = logits.argmax(dim=-1)
-            # q = q.reshape(q.size(0) // 2, 2, -1)
-            # input = input.reshape(input.size(0) // 2, 2, -1)
-            # q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
+            if generator is None:
+                # Either we start with the world quizzes
+                c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size)
+            else:
+                # Or we use the generator itself to generate them
+                c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine)
 
-            q = q.reshape(q.size(0) // 4, -1)
+            # We remove the trivial ones
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                f"culture_mix_ae_{n_epoch:04d}.png",
-                q,
-            )
+            # If there are remaining ones, we compute the true prolog
+            # that indicates how the GPTs solve it
 
-    return model
+            if c_quizzes.size(0) > 0:
+                seq_logproba = quiz_machine.models_logprobas(
+                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+                ) + quiz_machine.models_logprobas(
+                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+                )
 
+                probas = seq_logproba.exp()
 
-# if args.autoencoder_dim > 0:
-# ae = train_autoencoder()
-# exit(0)
+                nu = probas <= args.proba_not_understands
+                u = probas >= args.proba_understands
 
-######################################################################
+                prolog = (
+                    (nu.long() * token_prolog_0)
+                    + (u.long() * token_prolog_2)
+                    + ((nu == False & u == False).long() * token_prolog_1)
+                )
 
+                samples.append(torch.cat([prolog, c_quizzes], dim=1))
 
-models = []
+        # Now we yield a batch
 
-for k in range(args.nb_gpts):
-    log_string(f"creating model {k} and its w_quizzes")
+        x = torch.cat(samples, dim=0)
+        samples = [x[args.batch_size :]]
 
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        causal=True,
-        dropout=args.dropout,
-    ).to(main_device)
+        yield x[: args.batch_size]
 
-    model.main_test_accuracy = 0.0
-    model.id = k
 
-    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
-        args.nb_train_samples
+def one_generator_epoch(
+    generator, quiz_machine=None, models=None, local_device=main_device
+):
+    model.to(local_device).train()
+
+    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    hard_w_quizzes = []
+
+    full_input, full_from_w = quiz_machine.data_input(generator, split="train")
+    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+
+    for input, from_w in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=full_input.size(0) // args.batch_size,
+    ):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        targets = input
+
+        output = generator(mygpt.BracketedSequence(input)).x
+        loss_per_token = F.cross_entropy(
+            output.transpose(1, 2), targets, reduction="none"
+        )
+        loss = loss_per_token.mean()
+        acc_train_loss += loss.item() * input.size(0)
+
+        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
+        if from_w.any():
+            hard_w_quizzes.append(
+                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
+            )
+
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(
+        f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}"
     )
 
-    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+    run_tests(generator, quiz_machine)
+
+    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
+    threshold = threshold[threshold.size(0) // 2]
+
+    generator.hard_w_quizzes = torch.cat(
+        [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
+    )
+
+    generator.to(main_device)
 
-    models.append(model)
 
 ######################################################################
 
index 90879ce..bad05ec 100755 (executable)
@@ -84,7 +84,7 @@ class QuizMachine:
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
             (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
             (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
-            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
         ]
 
         self.LOCK_C_QUIZZES = threading.Lock()