From 58fc9b80c89dd8b6c7c2e9051b0dcb0fd2291231 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 09:31:27 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 126 +++--------------------------------------------- 1 file changed, 6 insertions(+), 120 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index cf70b91..134bf21 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -180,9 +180,9 @@ class QuizMachine: ###################################################################### - def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): + def make_ar_mask(self, quizzes, struct, mask): assert struct in self.train_struct - return self.problem.make_ar_mask(quizzes, struct, mask) + return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) def predict(self, model, quizzes, struct, mask): ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) @@ -216,6 +216,7 @@ class QuizMachine: nb = 0 + # We consider all the configurations that we train for for struct, mask in [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), @@ -228,11 +229,9 @@ class QuizMachine: result[i], correct[i] = self.predict( model=model, quizzes=input[i], struct=struct, mask=mask ) - predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :] - correct[i] = (2 * correct[i] - 1) * ( - predicted_parts[i].sum(dim=-1) == 1 - ).long() + solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 + correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() assert nb == input.size(0) @@ -270,7 +269,6 @@ class QuizMachine: def randomize_configuations_inplace(self, quizzes, structs): r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device) - for c in range(len(structs)): quizzes[r == c] = self.problem.reconfigure( quizzes[r == c], struct=structs[c] @@ -353,7 +351,7 @@ class QuizMachine: seq_logproba.split(self.batch_size), ): input = input.to(device) - ar_mask = self.make_ar_mask(input, struct, mask) + ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( @@ -366,118 +364,6 @@ class QuizMachine: return seq_logproba.to("cpu") - ############################################################### - - def models_successes(self, models_for_validation, c_quizzes, device=None): - if device is None: - device = self.device - - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - - seq_logproba = torch.zeros( - c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, - device=device, - ) - - correctly_solved = torch.empty( - c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, - device=device, - dtype=torch.int64, - ) - - seq_logproba[...] = 0.0 - - c_quizzes = c_quizzes.to(device) - - reversed_c_quizzes = self.problem.reconfigure( - c_quizzes, ("f_A", "A", "f_B", "B") - ) - - for model in models_for_validation: - # A, f(A), B | f(B) - result = c_quizzes.clone() - - ar_mask = self.make_ar_mask( - result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) - ) - - self.autoregression( - model=model, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - ) - - correct = (c_quizzes == result).long().min(dim=-1).values - - # ------------------------------- - - # f(A), A, f(B) | B - result = reversed_c_quizzes.clone() - - ar_mask = self.make_ar_mask( - result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1) - ) - - self.autoregression( - model=model, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - ) - - correct *= (reversed_c_quizzes == result).long().min(dim=-1).values - - # ------------------------------- - - correctly_solved[:, model.id] = correct - - return correctly_solved.to("cpu") - - ############################################################### - - def optimize_quizzes( - self, models, quiz, nb_variants, nb_iterations, struct, mask, proba_understands - ): - for _ in range(nb_iterations): - candidates = quiz[None, :].expand(nb_variants, -1).clone() - r = torch.rand(candidates.size(), device=candidates.device) - u = r.reshape(r.size(0), 4, candidates.size(1) // 4) - # Only change the part indicated by the mask and do not - # touch the special tokens - u[:, :, 0] = 0 - u = u * (1 - torch.tensor(mask, device=u.device)[None, :, None]) - random_mask = F.one_hot(r.argmax(dim=1), num_classes=r.size(1)) - # Keep the first unchanged - random_mask[0, :] = 0 - # Reshape without the 4 parts - candidates.reshape(-1, candidates.size(-1)) - random_mask.reshape(candidates.size()) - random_tokens = torch.randint( - self.problem.nb_token_values - 4, - random_mask.size(), - device=candidates.device, - ) - # Apply the noise - candidates = (1 - random_mask) * candidates + random_mask * random_tokens - seq_logproba = self.models_logprobas( - models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ) + self.models_logprobas( - models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) - ) - sorted_logprobas = seq_logproba.sort(dim=1).values.exp() - lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1] - score = second_lowest - lowest - - # score = score * (second_lowest > proba_understands) - - quiz = candidates[score.argmax()] - print(score.max()) - - return quiz.to("cpu") - ###################################################################### def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): -- 2.20.1