Update.
author     François Fleuret <francois@fleuret.org>
           Sat, 17 Aug 2024 08:42:42 +0000 (10:42 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sat, 17 Aug 2024 08:42:42 +0000 (10:42 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 915e10e..92bc05f 100755
--- a/main.py
+++ b/main.py
@@ -63,6 +63,8 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
+parser.add_argument("--lambda_H", type=float, default=0.0)
+
 parser.add_argument("--schedule_free", action="store_true", default=False)
 
 # ----------------------------------
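
The new --lambda_H flag adds a weight for an entropy term, but nothing in this diff uses it yet. A minimal sketch of how such a coefficient could enter the objective in one_epoch, assuming the (N, T, V) output layout used there (an illustration only, not the committed code):

    # Hypothetical use of args.lambda_H: add an entropy bonus to the loss so
    # that a larger lambda_H pushes the model towards higher output entropy.
    log_p = output.log_softmax(dim=2)                       # (N, T, V)
    entropy_per_token = -(log_p.exp() * log_p).sum(dim=2)   # (N, T), in nats
    loss = (
        (loss_per_token * mask_loss).mean()
        + model.loss
        - args.lambda_H * entropy_per_token.mean()
    )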
@@ -404,10 +406,20 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         targets = input
 
         output = model(mygpt.BracketedSequence(input)).x
+
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
         )
+
+        # warnings.warn("entropy masking", RuntimeWarning)
+        # l = output.transpose(1, 2).log_softmax(dim=1)
+        # H = -(l * l.exp()).sum(dim=1)
+        # M = (H >= -math.log(0.99) / H.size(1)).long()
+        # print(H, M)
+        # loss_per_token = loss_per_token * M
+
         loss = (loss_per_token * mask_loss).mean() + model.loss
+
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
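
The commented-out block above sketches an entropy mask: the per-token entropy of the predictive distribution is computed, and tokens on which the model is already near-certain (entropy below a small threshold) would be excluded from the loss. The same idea written out with explicit shapes, keeping the threshold from the comment (untested sketch, not active code):

    # Entropy-masking sketch; output is (N, T, V) logits, loss_per_token is (N, T).
    log_p = output.log_softmax(dim=2)                  # log-probabilities over the vocabulary
    H = -(log_p.exp() * log_p).sum(dim=2)              # (N, T) per-token entropy
    threshold = -math.log(0.99) / output.size(1)       # small constant, as in the comment
    M = (H >= threshold).long()                        # keep only tokens the model is unsure about
    loss_per_token = loss_per_token * M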
@@ -782,98 +794,6 @@ class Thinker(nn.Module):
         return bs
 
 
-if args.test == "func":
-    test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-
-    L = test_input.size(1) // 4
-    f_len = 50
-
-    model = Thinker(
-        vocabulary_size=vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        f_len=f_len,
-        dropout=args.dropout,
-    ).to(main_device)
-
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
-    for n_epoch in range(args.nb_epochs):
-        model.train()
-
-        train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-
-        nb_train_samples, acc_train_loss = 0, 0.0
-
-        for input in tqdm.tqdm(
-            train_input.split(args.batch_size),
-            dynamic_ncols=True,
-            desc="training",
-            total=train_input.size(0) // args.batch_size,
-        ):
-            input = input.to(main_device)
-
-            if nb_train_samples % args.batch_size == 0:
-                model.optimizer.zero_grad()
-
-            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
-            targets = input[:, 3 * L :]
-            loss = F.cross_entropy(output.transpose(1, 2), targets)
-            acc_train_loss += loss.item() * input.size(0)
-
-            nb_train_samples += input.size(0)
-
-            loss.backward()
-
-            if nb_train_samples % args.batch_size == 0:
-                model.optimizer.step()
-
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-        log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}")
-
-        with torch.autograd.no_grad():
-            model.eval()
-
-            nb_test_samples, acc_test_loss = 0, 0.0
-
-            for input in tqdm.tqdm(
-                test_input.split(args.batch_size),
-                dynamic_ncols=True,
-                desc="testing",
-                total=test_input.size(0) // args.batch_size,
-            ):
-                input = input.to(main_device)
-
-                output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
-                targets = input[:, 3 * L :]
-                loss = F.cross_entropy(output.transpose(1, 2), targets)
-                acc_test_loss += loss.item() * input.size(0)
-
-                nb_test_samples += input.size(0)
-
-            test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-            log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}")
-
-            input = test_input[:128].clone().to(main_device)
-
-            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
-            dist = torch.distributions.categorical.Categorical(logits=output)
-            input[:, 3 * L + 1 :] = dist.sample()[:, 1:]
-
-            problem.save_quizzes_as_image(
-                args.result_dir,
-                f"thinker_prediction_{n_epoch:04d}.png",
-                quizzes=input,
-                # predicted_parts=predicted_parts,
-                # correct_parts=correct_parts,
-            )
-
-
 ######################################################################
 
 models = []
@@ -913,7 +833,7 @@ for k in range(args.nb_gpts):
 
     model.test_accuracy = 0.0
     model.best_test_accuracy = 0.0
-
+    model.best_dict = copy.deepcopy(model.state_dict())
     models.append(model)
 
 ######################################################################
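
Each model now keeps a deep copy of its weights in model.best_dict. Where this snapshot is refreshed or restored is not part of this diff; a plausible pattern, assuming it follows model.best_test_accuracy (illustration only):

    # Hypothetical bookkeeping around best_dict after evaluation.
    if model.test_accuracy > model.best_test_accuracy:
        model.best_test_accuracy = model.test_accuracy
        model.best_dict = copy.deepcopy(model.state_dict())

    # ...and, when needed, rolling back to the best snapshot:
    # model.load_state_dict(model.best_dict)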
@@ -1071,25 +991,10 @@ if args.test == "mlp":
     exit(0)
 
 ######################################################################
-######################################################################
 
-if args.test == "reject":
-    record = []
-
-    c_quizzes_procedure = [
-        (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
-        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
-        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
-        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-        (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
-        (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-    ]
 
-    while sum([x.size(0) for x in record]) < 64:
+def save_generated_c_quizzes(model, filename, nb=64):
+    while sum([x.size(0) for x in record]) < nb:
         model = models[torch.randint(len(models), (1,)).item()]
         c_quizzes = quiz_machine.generate_c_quizzes(
             64,
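
Note that the new save_generated_c_quizzes helper (its body continues in the following hunks) still reads the record list that used to be initialized under the removed `if args.test == "reject":` branch, so as committed it would likely raise a NameError. A self-contained sketch of the same accumulation loop, with the parts this diff does not show (the filtering between generation and the NB_KEPT message, and the trailing arguments of the two calls) reduced to placeholders:

    def save_generated_c_quizzes(model, filename, nb=64):
        record = []  # missing in the committed version
        while sum(x.size(0) for x in record) < nb:
            # pick a model at random, as in the committed loop
            model = models[torch.randint(len(models), (1,)).item()]
            c_quizzes = quiz_machine.generate_c_quizzes(
                64,
                model_for_generation=model,
                procedure=c_quizzes_procedure,  # assumed; cut off in this hunk
            )
            record.append(c_quizzes)  # the committed filtering step is not visible here
            print("NB_KEPT", sum(x.size(0) for x in record))

        quiz_machine.problem.save_quizzes_as_image(
            args.result_dir,
            filename,
            quizzes=torch.cat(record, dim=0)[:nb],  # assumed; arguments cut off in the diff
        )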
@@ -1118,8 +1023,6 @@ if args.test == "reject":
 
         print("NB_KEPT", sum([x.size(0) for x in record]))
 
-    filename = f"sampling_with_rejection.png"
-
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         filename,
@@ -1128,6 +1031,40 @@ if args.test == "reject":
 
     log_string(f"wrote {filename}")
 
+
+######################################################################
+
+if args.test == "entropy":
+    model = models[0]
+    model.to(main_device)
+
+    log_string("starting testing entropy maximization")
+
+    train_input = quiz_machine.generate_c_quizzes(
+        1000,
+        model_for_generation=model,
+        procedure=c_quizzes_procedure,
+    )
+
+    for n_epoch in range(10):
+        nb_train_samples, acc_train_loss = 0, 0.0
+
+        for input in train_input.split(args.batch_size):
+            input = input.to(main_device)
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = output.log_softmax(dim=1).mean()
+
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
+
+            model.optimizer.zero_grad()
+            loss.backward()
+            model.optimizer.step()
+
+        log_string(
+            f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+        )
+
     exit(0)
 
 ######################################################################
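
In the new entropy test, the quantity logged as entropy is the mean log-softmax of the outputs, and the softmax is taken over dim=1, which for the (N, T, V) output used elsewhere in main.py is the sequence dimension rather than the vocabulary. For comparison, an explicit entropy-maximization objective would look roughly like the sketch below (illustration only, not the committed code):

    # Mean per-token entropy of the predictive distribution; minimizing its
    # negation increases the model's output entropy.
    log_p = output.log_softmax(dim=2)               # softmax over the vocabulary
    entropy = -(log_p.exp() * log_p).sum(dim=2)     # (N, T)
    loss = -entropy.mean()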
@@ -1187,7 +1124,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    if total_time_training_models < total_time_generating_c_quizzes:
+    if total_time_training_models <= total_time_generating_c_quizzes:
         ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
         weakest_models = ranked_models[: len(gpus)]
@@ -1212,6 +1149,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         total_time_training_models += time.perf_counter() - start_time
 
+        for model in weakest_models:
+            save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+
     # Save the models to disk
 
     for model in models:
@@ -1230,9 +1170,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
-    for model in weakest_models:
-        save_additional_results(n_epoch, model, models, c_quizzes_procedure)
-
     ######################################################################
 
     if args.log_command is not None:
diff --git a/quiz_machine.py b/quiz_machine.py
index 3c4a865..98e0ea5 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -338,7 +338,8 @@ class QuizMachine:
 
         c_quizzes = None
 
-        for s, m, mt in procedure:
+        for n_step, setup in enumerate(procedure):
+            s, m, mt = setup
             if c_quizzes is None:
                 c_quizzes = self.problem.create_empty_quizzes(nb, s)
                 c_quizzes = c_quizzes.to(self.device)
@@ -354,6 +355,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_quiz_mask(c_quizzes, s, m),
                 seq_logprobas=seq_logprobas,
+                progress_bar_desc=f"autoregression {n_step}/{len(procedure)}",
             )
 
             model_for_generation.reset_transformations()
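
The generation loop now passes a per-step description into the autoregression call. The receiving side is not shown in this diff; typically such a string ends up as the desc of a tqdm progress bar, roughly as in the sketch below (hypothetical helper, not the actual generation code):

    import tqdm

    def batches_with_progress(batches, progress_bar_desc=None):
        # Wrap an iterable of batches in a tqdm bar when a description is given.
        if progress_bar_desc is not None:
            batches = tqdm.tqdm(batches, dynamic_ncols=True, desc=progress_bar_desc)
        for input in batches:
            yield input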