From: François Fleuret Date: Thu, 18 Jul 2024 21:56:57 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=132c32eff580b32f1ed207dcc01d353a837dd7b5;p=culture.git Update. --- diff --git a/main.py b/main.py index bdfe9fd..0d0d373 100755 --- a/main.py +++ b/main.py @@ -90,7 +90,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=2.5) +parser.add_argument("--generation_temperature", type=float, default=1.5) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -419,7 +419,12 @@ def keep_good_quizzes(models, quizzes): ) elif args.c_quiz_validation_mode == "predict": - to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1) + nc = quiz_machine.solution_nb_correct(models, quizzes) + count_nc = tuple( + n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0) + ) + log_string(f"nb_correct {count_nc}") + to_keep = nc == (len(models) - 1) else: raise ValueError(f"{args.c_quiz_validation_mode=}") diff --git a/quiz_machine.py b/quiz_machine.py index bbd1b7b..032305a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -611,7 +611,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1, + temperature=0.75, deterministic_synthesis=False, device=self.device, ) @@ -624,7 +624,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1, + temperature=0.75, deterministic_synthesis=False, device=self.device, )