From f3643a28dc33f371235a55bf1bb5d4ba13a36c8d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 08:08:28 +0200 Subject: [PATCH] Update. --- main.py | 33 +++++++++++---- quiz_machine.py | 106 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 104 insertions(+), 35 deletions(-) diff --git a/main.py b/main.py index 5c58beb..b149e62 100755 --- a/main.py +++ b/main.py @@ -88,10 +88,12 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=2) +# parser.add_argument("--generation_temperature", type=float, default=2) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") +parser.add_argument("--forward_only", action="store_true", default=False) + parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -411,10 +413,10 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 while nb_validated < nb_to_create: model_for_generation = models[torch.randint(len(models), (1,))] - c_quizzes = quiz_machine.generate_quizzes( + c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, model_for_generation=model_for_generation, - temperature=args.generation_temperature, + forward_only=args.forward_only, ) c_quizzes = keep_good_quizzes(models, c_quizzes) @@ -482,10 +484,19 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k - model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples) - quiz_machine.reverse_random_half_in_place(model.train_w_quizzes) - model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples) - quiz_machine.reverse_random_half_in_place(model.test_w_quizzes) + quiz_machine.create_w_quizzes( + model=model, + nb=args.nb_train_samples, + for_train=True, + forward_only=args.forward_only, + ) + + quiz_machine.create_w_quizzes( + model=model, + nb=args.nb_test_samples, + for_train=False, + forward_only=args.forward_only, + ) models.append(model) @@ -659,7 +670,11 @@ for n_epoch in range(args.nb_epochs): # Renew the training samples for model in weakest_models: - quiz_machine.renew_w_quizzes(model, args.nb_train_samples) - + quiz_machine.renew_w_quizzes( + model=model, + nb=args.nb_train_samples, + for_train=True, + forward_only=args.forward_only, + ) ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 0f834dc..008e435 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -428,12 +428,26 @@ class QuizMachine: ###################################################################### - def renew_w_quizzes(self, model, nb, for_train=True): + def create_w_quizzes(self, model, nb, for_train=True, forward_only=False): + input = self.generate_token_sequences(nb) + + if not forward_only: + self.reverse_random_half_in_place(input) + + if for_train: + model.train_w_quizzes = input + else: + model.test_w_quizzes = input + + ###################################################################### + + def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False): input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) - self.reverse_random_half_in_place(fresh_w_quizzes) + if not forward_only: + self.reverse_random_half_in_place(fresh_w_quizzes) input[-nb:] = fresh_w_quizzes.to("cpu") ###################################################################### @@ -527,7 +541,7 @@ class QuizMachine: ############################################################### - def generate_quizzes(self, nb, model_for_generation, temperature=1.0): + def generate_c_quizzes(self, nb, model_for_generation, forward_only=False): c_quizzes = torch.empty( nb, self.prompt_len + self.answer_len + 2, @@ -537,29 +551,69 @@ class QuizMachine: seq_logproba = torch.zeros(nb, device=self.device) - c_quizzes[:, 0] = self.token_forward - c_quizzes[:, 1 + self.prompt_len] = self.token_forward - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=False, - device=self.device, - ) + if forward_only: + c_quizzes[:, 0] = self.token_forward + c_quizzes[:, 1 + self.prompt_len] = self.token_forward - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1, - deterministic_synthesis=False, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes, first=True), + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=False, + device=self.device, + ) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes), + seq_logproba=seq_logproba, + temperature=1, + deterministic_synthesis=False, + device=self.device, + ) + + else: + c_quizzes[:, 0] = self.token_backward + c_quizzes[:, 1 + self.answer_len] = self.token_backward + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes, first=True), + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=False, + device=self.device, + ) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes), + seq_logproba=seq_logproba, + temperature=1, + deterministic_synthesis=False, + device=self.device, + ) + + c_quizzes = self.reverse_time(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes), + seq_logproba=seq_logproba, + temperature=1, + deterministic_synthesis=False, + device=self.device, + ) return c_quizzes.to("cpu") -- 2.20.1