Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 10:54:28 +0000 (13:54 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 10:54:28 +0000 (13:54 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index 6137834..10c7b49 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -85,10 +85,6 @@ parser.add_argument("--problem", type=str, default="sky")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--nb_models_for_generation", type=int, default=1)
-
-parser.add_argument("--generation_mode", type=str, default="groupthink")
-
 parser.add_argument("--min_to_validate", type=int, default=4)
 
 parser.add_argument("--max_to_validate", type=int, default=4)
@@ -421,7 +417,7 @@ def create_c_quizzes(
         sum_logits += c_quizzes.size(0) * ave_seq_logproba
         sum_nb_c_quizzes += c_quizzes.size(0)
 
-        nb_correct = quizz_machine.comput_correctness(c_quizzes, models)
+        nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
 
         if args.dirty_debug:
             nb_correct = torch.randint(
@@ -435,7 +431,9 @@ def create_c_quizzes(
 
         nb_validated = valid_c_quizzes(recorded, standard_validity).size(0)
 
-        log_string(f"keep c_quizzes kept {nv} total {nb_validated} / {nb_to_create}")
+        log_string(
+            f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
+        )
 
     # ------------------------------------------------------------
 
@@ -510,6 +508,9 @@ for n_epoch in range(args.nb_epochs):
         f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
+    cta = " ".join([f"{float(m.main_test_accuracy):.02f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
     # replace a fraction of the w_quizzes with a fresh ones
     quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
 
index 8dc23a5..88f2c1c 100755 (executable)
@@ -286,7 +286,7 @@ class QuizzMachine:
 
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
-    def comput_correctness(self, c_quizzes, models_for_validation):
+    def compute_correctness(self, c_quizzes, models_for_validation):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)