From 0f8f74f917030ab5216f7b906cf5aae644ec372e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 10 Jul 2024 00:39:11 +0200 Subject: [PATCH] Update. --- main.py | 9 ++++++--- quiz_machine.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 57f79a3..634363f 100755 --- a/main.py +++ b/main.py @@ -392,7 +392,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis): def standard_validity(logproba): l = logproba.sort(dim=-1).values - return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95)) + return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) + # warnings.warn("TEST!!!", RuntimeWarning) + # print(l.exp()) + # return (l[:, 0] < math.log(0.99)) def valid_c_quizzes(recorded, criteria): @@ -482,8 +485,8 @@ def create_c_quizzes_( # ------------------------------------------------------------ - standard_validity = lambda nb_correct: torch.logical_and( - nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate + standard_validity = lambda nb_correct: (nb_correct >= args.min_to_validate) & ( + nb_correct <= args.max_to_validate ) file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") diff --git a/quiz_machine.py b/quiz_machine.py index c1477c9..0ae68d0 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -429,7 +429,7 @@ class QuizMachine: F.cross_entropy(output.transpose(1, 2), input, reduction="none") * ar_mask ) - l[:, model.id] = ce.sum(dim=-1) + l[:, model.id] = -ce.sum(dim=-1) return logproba -- 2.20.1