Add an --inference_batch_size option, use it for QuizMachine inference, and add an experimental train_complexifier() routine.
Author:     François Fleuret <francois@fleuret.org>
AuthorDate: Tue, 30 Jul 2024 17:41:04 +0000 (19:41 +0200)
Commit:     François Fleuret <francois@fleuret.org>
CommitDate: Tue, 30 Jul 2024 17:41:04 +0000 (19:41 +0200)
main.py

diff --git a/main.py b/main.py
index 455aa1c..19a3c29 100755
--- a/main.py
+++ b/main.py
@@ -49,6 +49,8 @@ parser.add_argument("--batch_size", type=int, default=None)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
+parser.add_argument("--inference_batch_size", type=int, default=None)
+
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
@@ -157,6 +159,7 @@ assert not args.grids_science_tasks or (
 default_args = {
     "model": "37M",
     "batch_size": 25,
+    "inference_batch_size": 100,
     "nb_train_samples": 100000,
     "nb_test_samples": 10000,
 }
@@ -336,7 +339,7 @@ if not args.resume:
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
-    batch_size=args.physical_batch_size,
+    batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
     device=main_device,
@@ -670,6 +673,121 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 ######################################################################
 
 
+def train_complexifier(model_gen, model_pred1, model_pred2):
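+    # Train model_gen on c_quizzes that model_pred1 solves with high
+    # confidence while model_pred2 does not.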
+    samples = []
+    perf = []
+
+    optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    for n_epoch in range(args.nb_epochs):
+        for b in range(args.nb_train_samples // args.batch_size):
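+            # Refill the buffer until it holds at least one full batch
+            # of non-trivial quizzes that pass the selection criterion.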
+            while sum([x.size(0) for x in samples]) < args.batch_size:
+                c_quizzes = quiz_machine.generate_c_quizzes(
+                    args.inference_batch_size,
+                    model_for_generation=model_gen,
+                    procedure=c_quizzes_procedure,
+                )
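+                # Discard the trivially solvable quizzes.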
+                to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+                c_quizzes = c_quizzes[to_keep]
+                if c_quizzes.size(0) > 0:
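+                    # Score the candidates with both predictors, in
+                    # both presentation orders, and keep the quizzes
+                    # that model_pred1 solves confidently while
+                    # model_pred2 does not.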
+                    seq_logproba = quiz_machine.models_logprobas(
+                        [model_pred1, model_pred2],
+                        c_quizzes,
+                        ("A", "f_A", "B", "f_B"),
+                        (0, 0, 0, 1),
+                    ) + quiz_machine.models_logprobas(
+                        [model_pred1, model_pred2],
+                        c_quizzes,
+                        ("f_A", "A", "f_B", "B"),
+                        (0, 0, 0, 1),
+                    )
+                    probas = seq_logproba.exp()
+                    to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
+                        probas[:, model_pred2.id] <= args.proba_not_understands
+                    )
+                    log_string(
+                        f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
+                    )
+                    c_quizzes = c_quizzes[to_keep]
+                    if c_quizzes.size(0):
+                        samples.append(c_quizzes)
+
+            log_string(f"full batch {sum([x.size(0) for x in samples])}")
+
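+            # Take one training batch from the buffer and keep the
+            # remainder for the next iteration.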
+            x = torch.cat(samples, dim=0)
+
+            input = x[: args.batch_size]
+            samples = [x[args.batch_size :]]
+
+            # -------------------
+
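+            # Re-score the selected batch to caption the saved image
+            # with both predictors' confidence on each quiz.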
+            seq_logproba = quiz_machine.models_logprobas(
+                [model_pred1, model_pred2],
+                input,
+                ("A", "f_A", "B", "f_B"),
+                (0, 0, 0, 1),
+            ) + quiz_machine.models_logprobas(
+                [model_pred1, model_pred2],
+                input,
+                ("f_A", "A", "f_B", "B"),
+                (0, 0, 0, 1),
+            )
+
+            comments = []
+
+            for lp in seq_logproba:
+                comments.append(
+                    f"proba {lp[model_pred1.id].exp().item():.02f} {lp[model_pred2.id].exp().item():.02f}"
+                )
+
+            filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir, filename, input, comments=comments
+            )
+            log_string(f"wrote {filename}")
+
+            # ------------------------
+
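+            # One optimizer step of model_gen on the selected quizzes,
+            # with the standard autoregressive cross-entropy.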
+            input = input.to(main_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.zero_grad()
+
+            output = model_gen(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
+
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_perplexity {n_epoch} model_gen {train_perplexity}")
+
+
+######################################################################
+
+
 def train_autoencoder():
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
@@ -769,9 +887,9 @@ def train_autoencoder():
     return model
 
 
-if args.autoencoder_dim > 0:
-    ae = train_autoencoder()
-    exit(0)
+# if args.autoencoder_dim > 0:
+#     ae = train_autoencoder()
+#     exit(0)
 
 ######################################################################
 
@@ -864,6 +982,14 @@ if args.dirty_debug:
 
 ######################################################################
 
+# DIRTY TEST
+
+# train_complexifier(models[0], models[1], models[2])
+
+# exit(0)
+
+######################################################################
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     state = {"current_epoch": n_epoch}
     filename = "state.pth"