From c272b4d4a39d5572f171a4eadf98975ece8a0eeb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 2 Aug 2024 06:42:32 +0200 Subject: [PATCH] Update. --- main.py | 43 +++++++++++++++++++++++------------- quiz_machine.py | 58 ++++++++++++++++++++++--------------------------- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/main.py b/main.py index 69452e6..059a29d 100755 --- a/main.py +++ b/main.py @@ -104,7 +104,7 @@ parser.add_argument("--temperature_cold", type=float, default=1) parser.add_argument("--prompt_noise", type=float, default=0.0) -parser.add_argument("--nb_averaging_rounds", type=int, default=1) +parser.add_argument("--nb_averaging_rounds", type=int, default=3) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -162,7 +162,7 @@ assert not args.grids_science_tasks or ( default_args = { "model": "37M", "batch_size": 25, - "inference_batch_size": 100, + "inference_batch_size": 50, "nb_train_samples": 100000, "nb_test_samples": 10000, } @@ -345,7 +345,6 @@ quiz_machine = quiz_machine.QuizMachine( batch_size=args.inference_batch_size, result_dir=args.result_dir, prompt_noise=args.prompt_noise, - nb_averaging_rounds=args.nb_averaging_rounds, logger=log_string, device=main_device, ) @@ -581,15 +580,20 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[to_keep] - # This is nb_quizzes x nb_models + probas = 0 - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + for a in range(args.nb_averaging_rounds): + # This is nb_quizzes x nb_models - probas = seq_logproba.exp() + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas += seq_logproba.exp() + + probas /= args.nb_averaging_rounds nb_succeed = (probas >= args.proba_understands).long().sum(dim=1) nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) @@ -650,11 +654,20 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: - seq_logproba = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + probas = 0 + + for a in range(args.nb_averaging_rounds): + # This is nb_quizzes x nb_models + + seq_logproba = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas += seq_logproba.exp() + + probas /= args.nb_averaging_rounds comments = [] diff --git a/quiz_machine.py b/quiz_machine.py index cfab73a..015f6d2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -68,7 +68,6 @@ class QuizMachine: batch_size, result_dir, prompt_noise, - nb_averaging_rounds, logger, device=torch.device("cpu"), ): @@ -80,11 +79,7 @@ class QuizMachine: self.logger = logger self.prompt_len = None self.answer_len = None - - assert prompt_noise > 0 or nb_averaging_rounds == 1 - self.prompt_noise = prompt_noise - self.nb_averaging_rounds = nb_averaging_rounds self.understood_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)), @@ -349,34 +344,33 @@ class QuizMachine: device=device, ) - for a in range(self.nb_averaging_rounds): - if self.prompt_noise > 0.0 and noise_mask is not None: - c_quizzes = self.problem.inject_noise( - c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask - ) + if self.prompt_noise > 0.0 and noise_mask is not None: + c_quizzes = self.problem.inject_noise( + c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask + ) - for model in models_for_validation: - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), - seq_logproba.split(self.batch_size), - ): - input = input.to(device) - ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) - output = model(mygpt.BracketedSequence(input)).x - l[:, model.id] += ( - -F.cross_entropy( - output.transpose(1, 2), input, reduction="none" - ) - * ar_mask - ).sum(dim=1) - - model.train(t) - - return seq_logproba.div(self.nb_averaging_rounds).to("cpu") + for model in models_for_validation: + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), + seq_logproba.split(self.batch_size), + ): + input = input.to(device) + ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) + output = model(mygpt.BracketedSequence(input)).x + l[:, model.id] = ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + * ar_mask + ).sum(dim=1) + + model.train(t) + + return seq_logproba.to("cpu") ###################################################################### -- 2.39.5