From 132c32eff580b32f1ed207dcc01d353a837dd7b5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 18 Jul 2024 23:56:57 +0200 Subject: [PATCH] Update. --- main.py | 9 +++++++-- quiz_machine.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index bdfe9fd..0d0d373 100755 --- a/main.py +++ b/main.py @@ -90,7 +90,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=2.5) +parser.add_argument("--generation_temperature", type=float, default=1.5) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -419,7 +419,12 @@ def keep_good_quizzes(models, quizzes): ) elif args.c_quiz_validation_mode == "predict": - to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1) + nc = quiz_machine.solution_nb_correct(models, quizzes) + count_nc = tuple( + n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0) + ) + log_string(f"nb_correct {count_nc}") + to_keep = nc == (len(models) - 1) else: raise ValueError(f"{args.c_quiz_validation_mode=}") diff --git a/quiz_machine.py b/quiz_machine.py index bbd1b7b..032305a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -611,7 +611,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1, + temperature=0.75, deterministic_synthesis=False, device=self.device, ) @@ -624,7 +624,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1, + temperature=0.75, deterministic_synthesis=False, device=self.device, ) -- 2.39.5