Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 19:47:10 +0000 (21:47 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 19:47:10 +0000 (21:47 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index ff36e98..9d36aba 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -90,6 +90,8 @@ parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
 parser.add_argument("--generation_temperature", type=float, default=2)
 
+parser.add_argument("--c_quiz_validation_mode", type=str, default="proba")
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
@@ -280,7 +282,8 @@ elif args.problem == "grids":
 else:
     raise ValueError
 
-problem.save_some_examples(args.result_dir)
+if not args.resume:
+    problem.save_some_examples(args.result_dir)
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
@@ -371,13 +374,20 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 def keep_good_quizzes(models, quizzes):
     quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
-    token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
 
-    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+    if args.c_quiz_validation_mode == "proba":
+        token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
+        l = token_logprobas.sum(dim=-1).sort(dim=-1).values
 
-    to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
-        l[:, 1] > math.log(args.proba_understands)
-    )
+        to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
+            l[:, 1] > math.log(args.proba_understands)
+        )
+
+    elif args.c_quiz_validation_mode == "predict":
+        to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1)
+
+    else:
+        raise ValueError(f"{args.c_quiz_validation_mode=}")
 
     if args.dirty_debug:
         # warnings.warn("DEBUG", RuntimeWarning)
@@ -417,12 +427,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         if nb_validated > 0 and nb_validated < nb_to_create:
             d = (nb_to_create - nb_validated) * duration / nb_validated
+            e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
+                "%a %H:%M"
+            )
         else:
-            d = 0
-
-        e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
-            "%a %H:%M"
-        )
+            e = "???"
 
         log_string(
             f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
index c2d1ec3..f66258a 100755 (executable)
@@ -486,16 +486,12 @@ class QuizMachine:
 
     ###############################################################
 
-    def compute_correctness(
+    def solution_nb_correct(
         self,
-        c_quizzes,
         models_for_validation,
-        bidirectional_validation=False,
-        deterministic_validation=True,
+        c_quizzes,
+        deterministic_validation=False,
     ):
-        if bidirectional_validation:
-            backward_c_quizzes = self.forward_to_backward(c_quizzes)
-
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
@@ -507,6 +503,7 @@ class QuizMachine:
         seq_logproba[...] = 0.0
 
         for model in models_for_validation:
+            c_quizzes = c_quizzes.to(self.device)
             result = c_quizzes.clone()
 
             ar_mask = self.make_ar_mask(result)
@@ -519,40 +516,14 @@ class QuizMachine:
                 seq_logproba=seq_logproba[:, model.id],
                 temperature=1.0,
                 deterministic_synthesis=deterministic_validation,
-                # progress_bar_desc="solving c_quizzes",
                 device=self.device,
             )
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            if bidirectional_validation:
-                backward_result = backward_c_quizzes.clone()
-
-                ar_mask = self.make_ar_mask(backward_result)
-
-                masked_inplace_autoregression(
-                    model=model,
-                    batch_size=self.batch_size,
-                    input=backward_result,
-                    ar_mask=ar_mask,
-                    seq_logproba=seq_logproba[:, model.id],
-                    temperature=1.0,
-                    deterministic_synthesis=deterministic_validation,
-                    # progress_bar_desc="solving backward c_quizzes",
-                    device=self.device,
-                )
-
-                backward_correct = (
-                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
-                )
-
-                correct *= backward_correct
-
-            # endif
-
             nb_correct += correct
 
-        return nb_correct, seq_logproba
+        return nb_correct.to("cpu")
 
     ###############################################################