From 98dbe305561906ad65deb5245aa7aeeb7a824fb2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 22:52:44 +0300 Subject: [PATCH] Update. --- main.py | 6 ++++-- quizz_machine.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 0a7be99..714327d 100755 --- a/main.py +++ b/main.py @@ -81,6 +81,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--reverse_cleanup", action="store_true", default=False) +parser.add_argument("--validation_forward_only", action="store_true", default=False) + parser.add_argument("--problem", type=str, default="sky") parser.add_argument("--nb_gpts", type=int, default=5) @@ -418,7 +420,7 @@ def create_c_quizzes( sum_nb_c_quizzes += c_quizzes.size(0) nb_correct = quizz_machine.compute_correctness( - c_quizzes, models, both_direction=True + c_quizzes, models, both_directions=not args.validation_forward_only ) if args.dirty_debug: @@ -513,7 +515,7 @@ for n_epoch in range(args.nb_epochs): f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}" ) - cta = " ".join([f"{float(m.main_test_accuracy):.02f}" for m in models]) + cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") # replace a fraction of the w_quizzes with a fresh ones diff --git a/quizz_machine.py b/quizz_machine.py index 198d279..4e7576e 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -287,7 +287,7 @@ class QuizzMachine: return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1) def compute_correctness( - self, c_quizzes, models_for_validation, both_direction=True + self, c_quizzes, models_for_validation, both_directions=True ): reversed_c_quizzes = self.reverse_time(c_quizzes) @@ -315,7 +315,7 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - if both_direction: + if both_directions: reversed_result = reversed_c_quizzes.clone() masked_inplace_autoregression( -- 2.39.5