Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 16:20:01 +0000 (18:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 16:20:01 +0000 (18:20 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 6c4099f..7ba5193 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -48,6 +48,10 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
+parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+
+parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
@@ -78,7 +82,7 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
 
 parser.add_argument("--proba_understands", type=float, default=0.99)
 
@@ -366,20 +370,28 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
+# This is the key routine that decides what generated quizzes to keep
+
+
+def compute_valid_quizzes(token_logprobas):
+    warnings.warn("validation with uniform constraints", RuntimeWarning)
+    l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
+    return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
+
 
-def standard_validity(logproba):
-    l = logproba.sort(dim=-1).values
+def compute_valid_quizzes_(token_logprobas):
+    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
     return (l[:, 0] < math.log(args.proba_not_understands)) & (
         l[:, 1] > math.log(args.proba_understands)
     )
 
 
-def valid_quizzes_and_logprobas(recorded, criteria):
+def extract_valid_quizzes_and_logprobas(recorded):
     validated_quizzes, validated_logprobas = [], []
-    for q, lp in recorded:
-        validated_indices = criteria(lp)
-        validated_quizzes.append(q[validated_indices])
-        validated_logprobas.append(lp[validated_indices])
+    for quizzes, token_logprobas in recorded:
+        validated_indices = compute_valid_quizzes(token_logprobas)
+        validated_quizzes.append(quizzes[validated_indices])
+        validated_logprobas.append(token_logprobas[validated_indices])
 
     if len(validated_quizzes) > 0:
         return torch.cat(validated_quizzes, dim=0), torch.cat(
@@ -411,12 +423,13 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
         c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
         if c_quizzes.size(0) > 0:
-            logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
-            recorded_quizzes_logprobas.append((c_quizzes, logproba))
+            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
+            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
 
-            validated_quizzes, validated_logprobas = valid_quizzes_and_logprobas(
-                recorded_quizzes_logprobas, standard_validity
-            )
+            (
+                validated_quizzes,
+                validated_logprobas,
+            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
 
             if validated_quizzes is not None:
                 nb_validated = validated_quizzes.size(0)
@@ -433,19 +446,6 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
         validated_quizzes[nb_for_train:nb_to_create], for_train=False
     )
 
-    ######################################################################
-    # save the log probas
-
-    file_name = os.path.join(
-        args.result_dir, f"culture_c_quiz_all_{n_epoch:04d}_logp.dat"
-    )
-
-    with open(file_name, "w") as logp_file:
-        for _, ll in recorded_quizzes_logprobas:
-            for l in ll:
-                s = " ".join([str(x.item()) for x in l])
-                logp_file.write(s + "\n")
-
     ######################################################################
     # save images with their logprobas
 
@@ -454,12 +454,12 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
-
-        file_name = os.path.join(args.result_dir, prefix + "_logp.dat")
-        with open(file_name, "w") as logp_file:
-            for l in vl:
-                s = " ".join([str(x.item()) for x in l])
-                logp_file.write(s + "\n")
+        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
+        torch.save(vl, filename)
+        with open(file_name, "w") as logp_file:
+        # for l in vl:
+        # s = " ".join([str(x.item()) for x in l])
+        # logp_file.write(s + "\n")
 
         quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
 
@@ -574,11 +574,14 @@ if args.max_percents_of_test_in_train >= 0:
 
 ######################################################################
 
-nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+if args.nb_new_c_quizzes_for_train is None:
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+
+if args.nb_new_c_quizzes_for_test is None:
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
 
 log_string(
-    f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+    f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
 )
 
 ######################################################################
@@ -586,12 +589,8 @@ log_string(
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
     args.nb_gpts = 2
-    nb_new_c_quizzes_for_train = 100
-    nb_new_c_quizzes_for_test = 10
-
-    def standard_validity(logproba):
-        l = logproba.sort(dim=-1).values
-        return l[:, 0] < math.log(0.5)
+    args.nb_new_c_quizzes_for_train = 100
+    args.nb_new_c_quizzes_for_test = 10
 
 
 ######################################################################
@@ -602,6 +601,22 @@ for n_epoch in range(args.nb_epochs):
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
+    ##################################################
+    # If all the models are good enough, generate new quizzes and
+    # re-compute the test errors
+
+    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        create_c_quizzes(
+            models,
+            quiz_machine,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
+            nb_for_test=args.nb_new_c_quizzes_for_test,
+        )
+
+        filename = "c_quizzes.pth"
+        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"wrote {filename}")
+
     ##################################################
     # Select, improve, and eval the worst model
 
@@ -640,20 +655,5 @@ for n_epoch in range(args.nb_epochs):
     for model in weakest_models:
         quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
 
-    ##################################################
-    # If all the models are good enough, generate new quizzes and
-    # re-compute the test errors
-
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
-            models,
-            quiz_machine,
-            nb_for_train=nb_new_c_quizzes_for_train,
-            nb_for_test=nb_new_c_quizzes_for_test,
-        )
-
-        filename = "c_quizzes.pth"
-        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
-        log_string(f"wrote {filename}")
 
 ######################################################################
index c49ecf2..bc468d3 100755 (executable)
@@ -450,9 +450,13 @@ class QuizMachine:
 
     ######################################################################
 
-    def logproba_of_solutions(self, models, c_quizzes):
+    def solution_token_logprobas(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+            c_quizzes.size(0),
+            len(models),
+            c_quizzes.size(1),
+            device=self.device,
+            dtype=torch.float32,
         )
 
         for model in models:
@@ -466,11 +470,12 @@ class QuizMachine:
                     input = input.to(self.device)
                     ar_mask = self.make_ar_mask(input)
                     output = model(mygpt.BracketedSequence(input)).x
-                    ce = (
-                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
                         * ar_mask
                     )
-                    l[:, model.id] = -ce.sum(dim=-1)
 
                 model.train(t)