From 0d84e2c1631b6802e07965e1109aa8e47c932824 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 28 Jul 2024 21:06:10 +0200 Subject: [PATCH] Update. --- main.py | 15 ++++++++------- quiz_machine.py | 35 ++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index a165696..0a148b1 100755 --- a/main.py +++ b/main.py @@ -89,7 +89,7 @@ parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--max_fail_to_validate", type=int, default=1) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98) parser.add_argument("--proba_understands", type=float, default=0.99) @@ -584,16 +584,17 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ) probas = seq_logproba.exp() - nb_sure_correct = (probas >= args.proba_understands).long().sum(dim=1) - nb_sure_fail = (probas <= args.proba_understands).long().sum(dim=1) + + nb_succeed = (probas >= args.proba_understands).long().sum(dim=1) + nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) to_keep = ( - (nb_sure_correct + nb_sure_fail == probas.size(1)) - & (nb_sure_fail >= 1) - & (nb_sure_fail <= args.max_fail_to_validate) + (nb_succeed + nb_fail == probas.size(1)) + & (nb_fail >= 1) + & (nb_fail <= args.max_fail_to_validate) ) - to_recycle = c_quizzes[to_keep == False] if not to_keep.all() else None + to_recycle = c_quizzes[to_keep == False] c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: diff --git a/quiz_machine.py b/quiz_machine.py index f147983..34f6b62 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -444,24 +444,40 @@ class QuizMachine: ############################################################### - def optimize_quizzes(self, quizzes, nb_variants, nb_iterations, struct, mask): + def optimize_quizzes(self, quiz, nb_variants, nb_iterations, struct, mask): for _ in range(nb_iterations): - candidates = quizzes[:, None].expand(-1, nb_variants, -1) + candidates = quizzes[None].expand(nb_variants, -1) r = torch.rand(candidates.size(), device=candidates.device) - u = r.reshape( - candidates.size(0) * candidates.size(1), 4, candidates.size(2) // 4 - ) + u = r.reshape(r.size(0), 4, candidates.size(1) // 4) + # Only change the part indicated by the mask and do not + # touch the special tokens u[:, :, 0] = 0 u = u * torch.tensor(mask, device=u.device)[None, :, None] random_mask = (r.sort(dim=0, descending=True).indices == 0).long() - random_mask[:, 0] = 0 + # Keep the first unchanged + random_mask[:, 0, :] = 0 + # Reshape without the 4 parts candidates.reshape(-1, candidates.size(-1)) random_mask.reshape(candidates.size()) random_tokens = torch.randint( self.problem.nb_token_values - 4, random_mask.size() ) + # Apply the noise candidates = (1 - random_mask) * candidates + random_mask * random_tokens - ar_mask = (self.make_ar_mask(candidates, struct, make_ar_mask),) + seq_logproba = quiz_machine.models_logprobas( + models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) + sorted_logprobas = seq_logproba.sort(dim=1).values.exp() + lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1] + score = second_lowest - lowest + + score = score * (second_lowest > args.proba_understands) + + quiz = candidates[score.argmax()] + + return quiz def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): seq_logproba = torch.zeros(nb, device=self.device) @@ -484,10 +500,11 @@ class QuizMachine: logit_transformer=t, ) - if to_recycle is not None: + if to_recycle is not None and to_recycle.size(0) > 0: to_recycle = self.problem.reconfigure(to_recycle, s) c_quizzes[: to_recycle.size(0)] = to_recycle - to_recycle = None + + to_recycle = None c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) -- 2.39.5