Merge branch 'dev'
[culture.git] / main.py
diff --git a/main.py b/main.py
index 6c4099f..6b00bbf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -48,6 +48,10 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
+parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+
+parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
@@ -78,13 +82,13 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
 
-parser.add_argument("--proba_understands", type=float, default=0.99)
+parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--generation_temperature", type=float, default=1.0)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -366,20 +370,31 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
+# This is the key routine that decides what generated quizzes to keep
+
+
+# token_logprobas are NxMxT where M is the number of models
+
+
+def compute_valid_quizzes_(token_logprobas):
+    warnings.warn("validation with uniform constraints", RuntimeWarning)
+    l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
+    return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
 
-def standard_validity(logproba):
-    l = logproba.sort(dim=-1).values
+
+def compute_valid_quizzes(token_logprobas):
+    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
     return (l[:, 0] < math.log(args.proba_not_understands)) & (
         l[:, 1] > math.log(args.proba_understands)
     )
 
 
-def valid_quizzes_and_logprobas(recorded, criteria):
+def extract_valid_quizzes_and_logprobas(recorded):
     validated_quizzes, validated_logprobas = [], []
-    for q, lp in recorded:
-        validated_indices = criteria(lp)
-        validated_quizzes.append(q[validated_indices])
-        validated_logprobas.append(lp[validated_indices])
+    for quizzes, token_logprobas in recorded:
+        validated_indices = compute_valid_quizzes(token_logprobas)
+        validated_quizzes.append(quizzes[validated_indices])
+        validated_logprobas.append(token_logprobas[validated_indices])
 
     if len(validated_quizzes) > 0:
         return torch.cat(validated_quizzes, dim=0), torch.cat(
@@ -411,12 +426,13 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
         c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
         if c_quizzes.size(0) > 0:
-            logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
-            recorded_quizzes_logprobas.append((c_quizzes, logproba))
+            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
+            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
 
-            validated_quizzes, validated_logprobas = valid_quizzes_and_logprobas(
-                recorded_quizzes_logprobas, standard_validity
-            )
+            (
+                validated_quizzes,
+                validated_logprobas,
+            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
 
             if validated_quizzes is not None:
                 nb_validated = validated_quizzes.size(0)
@@ -433,19 +449,6 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
         validated_quizzes[nb_for_train:nb_to_create], for_train=False
     )
 
-    ######################################################################
-    # save the log probas
-
-    file_name = os.path.join(
-        args.result_dir, f"culture_c_quiz_all_{n_epoch:04d}_logp.dat"
-    )
-
-    with open(file_name, "w") as logp_file:
-        for _, ll in recorded_quizzes_logprobas:
-            for l in ll:
-                s = " ".join([str(x.item()) for x in l])
-                logp_file.write(s + "\n")
-
     ######################################################################
     # save images with their logprobas
 
@@ -454,12 +457,12 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
-
-        file_name = os.path.join(args.result_dir, prefix + "_logp.dat")
-        with open(file_name, "w") as logp_file:
-            for l in vl:
-                s = " ".join([str(x.item()) for x in l])
-                logp_file.write(s + "\n")
+        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
+        torch.save(vl, filename)
+        with open(file_name, "w") as logp_file:
+        # for l in vl:
+        # s = " ".join([str(x.item()) for x in l])
+        # logp_file.write(s + "\n")
 
         quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
 
@@ -574,11 +577,14 @@ if args.max_percents_of_test_in_train >= 0:
 
 ######################################################################
 
-nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+if args.nb_new_c_quizzes_for_train is None:
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+
+if args.nb_new_c_quizzes_for_test is None:
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
 
 log_string(
-    f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+    f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
 )
 
 ######################################################################
@@ -586,12 +592,8 @@ log_string(
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
     args.nb_gpts = 2
-    nb_new_c_quizzes_for_train = 100
-    nb_new_c_quizzes_for_test = 10
-
-    def standard_validity(logproba):
-        l = logproba.sort(dim=-1).values
-        return l[:, 0] < math.log(0.5)
+    args.nb_new_c_quizzes_for_train = 100
+    args.nb_new_c_quizzes_for_test = 10
 
 
 ######################################################################
@@ -602,6 +604,26 @@ for n_epoch in range(args.nb_epochs):
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
+    ##################################################
+    # If all the models are good enough, generate new quizzes and
+    # re-compute the test errors
+
+    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        create_c_quizzes(
+            models,
+            quiz_machine,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
+            nb_for_test=args.nb_new_c_quizzes_for_test,
+        )
+
+        filename = "c_quizzes.pth"
+        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"wrote {filename}")
+
+        # Force one epoch of training
+        for model in models:
+            model.main_test_accuracy = 0.0
+
     ##################################################
     # Select, improve, and eval the worst model
 
@@ -640,20 +662,5 @@ for n_epoch in range(args.nb_epochs):
     for model in weakest_models:
         quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
 
-    ##################################################
-    # If all the models are good enough, generate new quizzes and
-    # re-compute the test errors
-
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
-            models,
-            quiz_machine,
-            nb_for_train=nb_new_c_quizzes_for_train,
-            nb_for_test=nb_new_c_quizzes_for_test,
-        )
-
-        filename = "c_quizzes.pth"
-        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
-        log_string(f"wrote {filename}")
 
 ######################################################################