From eaed6307836d88abe7c0f4be733a38364ba20e2f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 2 Jul 2024 22:23:59 +0300 Subject: [PATCH] Update. --- main.py | 20 +++++++++++++------- quizz_machine.py | 4 ++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 918f75d..5484f39 100755 --- a/main.py +++ b/main.py @@ -79,23 +79,23 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument("--both_directions", action="store_true", default=False) +parser.add_argument("--bidirectional_validation", action="store_true", default=False) parser.add_argument("--problem", type=str, default="sky") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=4) +parser.add_argument("--min_to_validate", type=int, default=None) -parser.add_argument("--max_to_validate", type=int, default=4) +parser.add_argument("--max_to_validate", type=int, default=None) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) parser.add_argument("--dirty_debug", action="store_true", default=False) -parser.add_argument("--generation_temperature", type=float, default=1.0) +parser.add_argument("--generation_temperature", type=float, default=2.0) -parser.add_argument("--stochastic_validation", action="store_true", default=False) +parser.add_argument("--deterministic_validation", action="store_true", default=False) ###################################################################### @@ -113,6 +113,12 @@ parser.add_argument("--sky_speed", type=int, default=3) args = parser.parse_args() +if args.min_to_validate is None: + args.min_to_validate = args = nb_gpts - 1 + +if args.max_to_validate is None: + args.max_to_validate = args = nb_gpts - 1 + if args.result_dir is None: args.result_dir = f"results_culture" @@ -423,8 +429,8 @@ def create_c_quizzes( nb_correct, seq_logproba = quizz_machine.compute_correctness( c_quizzes, models, - both_directions=args.both_directions, - deterministic_validation=not args.stochastic_validation, + bidirectional_validation=args.bidirectional_validation, + deterministic_validation=args.deterministic_validation, ) for n, l in zip(nb_correct, seq_logproba): diff --git a/quizz_machine.py b/quizz_machine.py index 9b64941..de85520 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -325,7 +325,7 @@ class QuizzMachine: self, c_quizzes, models_for_validation, - both_directions=False, + bidirectional_validation=False, deterministic_validation=True, ): reversed_c_quizzes = self.reverse_time(c_quizzes) @@ -360,7 +360,7 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - if both_directions: + if bidirectional_validation: reversed_result = reversed_c_quizzes.clone() masked_inplace_autoregression( -- 2.20.1