Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 19:52:44 +0000 (22:52 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 19:52:44 +0000 (22:52 +0300)
main.py
quizz_machine.py

diff --git a/main.py b/main.py
index 0a7be99..714327d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -81,6 +81,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--reverse_cleanup", action="store_true", default=False)
 
+parser.add_argument("--validation_forward_only", action="store_true", default=False)
+
 parser.add_argument("--problem", type=str, default="sky")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
@@ -418,7 +420,7 @@ def create_c_quizzes(
         sum_nb_c_quizzes += c_quizzes.size(0)
 
         nb_correct = quizz_machine.compute_correctness(
-            c_quizzes, models, both_direction=True
+            c_quizzes, models, both_directions=not args.validation_forward_only
         )
 
         if args.dirty_debug:
@@ -513,7 +515,7 @@ for n_epoch in range(args.nb_epochs):
         f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
     )
 
-    cta = " ".join([f"{float(m.main_test_accuracy):.02f}" for m in models])
+    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
     # replace a fraction of the w_quizzes with a fresh ones
index 198d279..4e7576e 100755 (executable)
@@ -287,7 +287,7 @@ class QuizzMachine:
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_direction=True
+        self, c_quizzes, models_for_validation, both_directions=True
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -315,7 +315,7 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            if both_direction:
+            if both_directions:
                 reversed_result = reversed_c_quizzes.clone()
 
                 masked_inplace_autoregression(