From: François Fleuret Date: Sun, 28 Jul 2024 15:18:33 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e9605e0b8dcba36f84f63a8c857c5cb0263c6906;p=culture.git Update. --- diff --git a/main.py b/main.py index ca84d3a..8eeb8de 100755 --- a/main.py +++ b/main.py @@ -91,7 +91,7 @@ parser.add_argument("--max_fail_to_validate", type=int, default=1) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) -parser.add_argument("--proba_understands", type=float, default=0.9) +parser.add_argument("--proba_understands", type=float, default=0.99) parser.add_argument("--proba_not_understands", type=float, default=0.5) @@ -577,33 +577,24 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # This is nb_quizzes x nb_models - number_correct_responses = 0 - nb_remaining = [c_quizzes.size(0)] - - for r in range(args.nb_rounds): - if c_quizzes.size(0) == 0: - break - - number_correct_responses += quiz_machine.models_successes(models, c_quizzes) - - nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1) - nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1) - - to_keep = ( - (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1)) - & (nb_sure_fail >= 1) - & (nb_sure_fail <= args.max_fail_to_validate) - ) - - if not to_keep.all(): - rejected.append(c_quizzes[to_keep == False]) + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) - c_quizzes = c_quizzes[to_keep] - number_correct_responses = number_correct_responses[to_keep] + probas = seq_logproba.exp() + nb_sure_correct = (probas >= args.proba_understands).long().sum(dim=1) + nb_sure_fail = (probas <= args.proba_understands).long().sum(dim=1) - nb_remaining.append(c_quizzes.size(0)) + to_keep = ( + (nb_sure_correct + nb_sure_fail == probas.size(1)) + & (nb_sure_fail >= 1) + & (nb_sure_fail <= args.max_fail_to_validate) + ) - to_recycle = torch.cat(rejected, dim=0) if len(rejected) > 0 else None + to_recycle = c_quizzes[to_keep == False] if not to_keep.all() else None + c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) @@ -628,10 +619,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 else: e = "???" - v = " ".join([str(n) for n in nb_remaining]) - log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h) filtering {v}" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" ) validated_quizzes = torch.cat(recorded_validated, dim=0) @@ -651,24 +640,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: - vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B")) - number_correct_responses = 0 - - for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"): - number_correct_responses += quiz_machine.models_successes(models, vq) - - seq_logproba = quiz_machine.models_logprobas(models, vq) + seq_logproba = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) comments = [] - for l, r in zip(seq_logproba, number_correct_responses): - comments.append( - "nb_correct " - + " ".join([str(n.item()) for n in r]) - + "\n" - + "proba " - + " ".join([str(x.item()) for x in l]) - ) + for l in seq_logproba: + comments.append(+"proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) filename = f"culture_c_quiz_{n_epoch:04d}.png" quiz_machine.problem.save_quizzes_as_image( diff --git a/quiz_machine.py b/quiz_machine.py index 5dec85c..f147983 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -335,11 +335,13 @@ class QuizMachine: ###################################################################### - def models_logprobas(self, models_for_validation, c_quizzes, device=None): + def models_logprobas( + self, models_for_validation, c_quizzes, struct, mask, device=None + ): if device is None: device = self.device - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) + c_quizzes = self.problem.reconfigure(c_quizzes, struct) seq_logproba = torch.zeros( c_quizzes.size(0), @@ -357,14 +359,14 @@ class QuizMachine: seq_logproba.split(self.batch_size), ): input = input.to(device) - ar_mask = self.make_ar_mask(input) + ar_mask = self.make_ar_mask(input, struct, mask) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( output.transpose(1, 2), input, reduction="none" ) * ar_mask - ).sum() + ).sum(dim=1) model.train(t) @@ -442,6 +444,25 @@ class QuizMachine: ############################################################### + def optimize_quizzes(self, quizzes, nb_variants, nb_iterations, struct, mask): + for _ in range(nb_iterations): + candidates = quizzes[:, None].expand(-1, nb_variants, -1) + r = torch.rand(candidates.size(), device=candidates.device) + u = r.reshape( + candidates.size(0) * candidates.size(1), 4, candidates.size(2) // 4 + ) + u[:, :, 0] = 0 + u = u * torch.tensor(mask, device=u.device)[None, :, None] + random_mask = (r.sort(dim=0, descending=True).indices == 0).long() + random_mask[:, 0] = 0 + candidates.reshape(-1, candidates.size(-1)) + random_mask.reshape(candidates.size()) + random_tokens = torch.randint( + self.problem.nb_token_values - 4, random_mask.size() + ) + candidates = (1 - random_mask) * candidates + random_mask * random_tokens + ar_mask = (self.make_ar_mask(candidates, struct, make_ar_mask),) + def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): seq_logproba = torch.zeros(nb, device=self.device)