From: François Fleuret Date: Tue, 9 Jul 2024 22:39:11 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0f8f74f917030ab5216f7b906cf5aae644ec372e;p=culture.git Update. --- diff --git a/main.py b/main.py index 57f79a3..634363f 100755 --- a/main.py +++ b/main.py @@ -392,7 +392,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis): def standard_validity(logproba): l = logproba.sort(dim=-1).values - return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95)) + return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) + # warnings.warn("TEST!!!", RuntimeWarning) + # print(l.exp()) + # return (l[:, 0] < math.log(0.99)) def valid_c_quizzes(recorded, criteria): @@ -482,8 +485,8 @@ def create_c_quizzes_( # ------------------------------------------------------------ - standard_validity = lambda nb_correct: torch.logical_and( - nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate + standard_validity = lambda nb_correct: (nb_correct >= args.min_to_validate) & ( + nb_correct <= args.max_to_validate ) file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") diff --git a/quiz_machine.py b/quiz_machine.py index c1477c9..0ae68d0 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -429,7 +429,7 @@ class QuizMachine: F.cross_entropy(output.transpose(1, 2), input, reduction="none") * ar_mask ) - l[:, model.id] = ce.sum(dim=-1) + l[:, model.id] = -ce.sum(dim=-1) return logproba