From cb6ac01c534a1ea16de22e44fac1a6e4a85ed29c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 28 Jul 2024 13:43:26 +0200 Subject: [PATCH] Update. --- main.py | 17 +++++++++++++---- quiz_machine.py | 41 +++++++++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index 83fd8b8..ca84d3a 100755 --- a/main.py +++ b/main.py @@ -651,16 +651,25 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: + vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B")) number_correct_responses = 0 for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"): number_correct_responses += quiz_machine.models_successes(models, vq) + seq_logproba = quiz_machine.models_logprobas(models, vq) + comments = [] - for r in number_correct_responses: - comments.append("nb_correct " + " ".join([str(n.item()) for n in r])) - vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B")) + for l, r in zip(seq_logproba, number_correct_responses): + comments.append( + "nb_correct " + + " ".join([str(n.item()) for n in r]) + + "\n" + + "proba " + + " ".join([str(x.item()) for x in l]) + ) + filename = f"culture_c_quiz_{n_epoch:04d}.png" quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, vq, comments=comments @@ -906,7 +915,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): model.main_test_accuracy = 0.0 ################################################## - # Select, improve, and eval the worst model + # Select, improve, and eval the worst model(s) ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy)) diff --git a/quiz_machine.py b/quiz_machine.py index ba3387c..5dec85c 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -335,56 +335,65 @@ class QuizMachine: ###################################################################### - def solution_token_logprobas(self, models, c_quizzes): - logproba = c_quizzes.new_zeros( + def models_logprobas(self, models_for_validation, c_quizzes, device=None): + if device is None: + device = self.device + + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) + + seq_logproba = torch.zeros( c_quizzes.size(0), - len(models), - c_quizzes.size(1), - device=self.device, - dtype=torch.float32, + max([m.id for m in models_for_validation]) + 1, + device=device, ) - for model in models: + for model in models_for_validation: with torch.autograd.no_grad(): t = model.training model.eval() for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + c_quizzes.split(self.batch_size), + seq_logproba.split(self.batch_size), ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input, shape="fwd_3_bck_123") + input = input.to(device) + ar_mask = self.make_ar_mask(input) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( output.transpose(1, 2), input, reduction="none" ) * ar_mask - ) + ).sum() model.train(t) - return logproba.to("cpu") + return seq_logproba.to("cpu") ############################################################### - def models_successes(self, models_for_validation, c_quizzes): + def models_successes(self, models_for_validation, c_quizzes, device=None): + if device is None: + device = self.device + + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) + seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, - device=self.device, + device=device, ) correctly_solved = torch.empty( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, - device=self.device, + device=device, dtype=torch.int64, ) seq_logproba[...] = 0.0 - c_quizzes = c_quizzes.to(self.device) + c_quizzes = c_quizzes.to(device) reversed_c_quizzes = self.problem.reconfigure( c_quizzes, ("f_A", "A", "f_B", "B") -- 2.39.5