From 0885300179bb93247fd702b9c3f980162190e7e5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 Jul 2024 21:47:10 +0200 Subject: [PATCH] Update. --- main.py | 31 ++++++++++++++++++++----------- quiz_machine.py | 39 +++++---------------------------------- 2 files changed, 25 insertions(+), 45 deletions(-) diff --git a/main.py b/main.py index ff36e98..9d36aba 100755 --- a/main.py +++ b/main.py @@ -90,6 +90,8 @@ parser.add_argument("--proba_not_understands", type=float, default=0.5) parser.add_argument("--generation_temperature", type=float, default=2) +parser.add_argument("--c_quiz_validation_mode", type=str, default="proba") + parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -280,7 +282,8 @@ elif args.problem == "grids": else: raise ValueError -problem.save_some_examples(args.result_dir) +if not args.resume: + problem.save_some_examples(args.result_dir) quiz_machine = quiz_machine.QuizMachine( problem=problem, @@ -371,13 +374,20 @@ def one_epoch(model, quiz_machine, local_device=main_device): def keep_good_quizzes(models, quizzes): quizzes = quizzes[quiz_machine.non_trivial(quizzes)] - token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes) - l = token_logprobas.sum(dim=-1).sort(dim=-1).values + if args.c_quiz_validation_mode == "proba": + token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes) + l = token_logprobas.sum(dim=-1).sort(dim=-1).values - to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & ( - l[:, 1] > math.log(args.proba_understands) - ) + to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & ( + l[:, 1] > math.log(args.proba_understands) + ) + + elif args.c_quiz_validation_mode == "predict": + to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1) + + else: + raise ValueError(f"{args.c_quiz_validation_mode=}") if args.dirty_debug: # warnings.warn("DEBUG", RuntimeWarning) @@ -417,12 +427,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if nb_validated > 0 and nb_validated < nb_to_create: d = (nb_to_create - nb_validated) * duration / nb_validated + e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( + "%a %H:%M" + ) else: - d = 0 - - e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( - "%a %H:%M" - ) + e = "???" log_string( f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})" diff --git a/quiz_machine.py b/quiz_machine.py index c2d1ec3..f66258a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -486,16 +486,12 @@ class QuizMachine: ############################################################### - def compute_correctness( + def solution_nb_correct( self, - c_quizzes, models_for_validation, - bidirectional_validation=False, - deterministic_validation=True, + c_quizzes, + deterministic_validation=False, ): - if bidirectional_validation: - backward_c_quizzes = self.forward_to_backward(c_quizzes) - seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, @@ -507,6 +503,7 @@ class QuizMachine: seq_logproba[...] = 0.0 for model in models_for_validation: + c_quizzes = c_quizzes.to(self.device) result = c_quizzes.clone() ar_mask = self.make_ar_mask(result) @@ -519,40 +516,14 @@ class QuizMachine: seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving c_quizzes", device=self.device, ) correct = (c_quizzes == result).long().min(dim=-1).values - if bidirectional_validation: - backward_result = backward_c_quizzes.clone() - - ar_mask = self.make_ar_mask(backward_result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=backward_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving backward c_quizzes", - device=self.device, - ) - - backward_correct = ( - (backward_c_quizzes == backward_result).long().min(dim=-1).values - ) - - correct *= backward_correct - - # endif - nb_correct += correct - return nb_correct, seq_logproba + return nb_correct.to("cpu") ############################################################### -- 2.39.5