From fbc162363d4a06c9f491e4ce564c659d96a63568 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 27 Jul 2024 08:33:59 +0200 Subject: [PATCH] Update. --- main.py | 15 +++- quiz_machine.py | 188 ++++++++++++++++++++++++++++++------------------ 2 files changed, 131 insertions(+), 72 deletions(-) diff --git a/main.py b/main.py index 848ac9c..be5e1bd 100755 --- a/main.py +++ b/main.py @@ -526,6 +526,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64) + to_recycle = None + while nb_validated_per_model.sum() < nb_to_validate: # We use the model that has generated the fewest quizzes to # balance the number of quizzes per model overall @@ -542,6 +544,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 model_for_generation=model_for_generation, temperature_hot=args.temperature_hot, temperature_cold=args.temperature_cold, + to_recycle=to_recycle, ) # We discard the trivial ones, according to a criterion @@ -561,6 +564,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 number_correct_responses = 0 nb_remaining = [c_quizzes.size(0)] + rejected = [] for r in range(args.nb_rounds): if c_quizzes.size(0) == 0: @@ -577,11 +581,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 & (nb_sure_fail <= args.max_fail_to_validate) ) + if not to_keep.all(): + rejected.append(c_quizzes[to_keep == False]) + c_quizzes = c_quizzes[to_keep] number_correct_responses = number_correct_responses[to_keep] nb_remaining.append(c_quizzes.size(0)) + to_recycle = torch.cat(rejected, dim=0) if len(rejected) > 0 else None + if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) recorded_validated.append(c_quizzes) @@ -606,10 +615,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" v = " ".join([str(n) for n in nb_remaining]) - log_string(f"filter c_quizzes {v}") log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h) filtering {v}" ) validated_quizzes = torch.cat(recorded_validated, dim=0) @@ -630,7 +638,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if vq.size(0) > 0: number_correct_responses = 0 - for r in range(10): + + for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"): number_correct_responses += quiz_machine.models_successes(models, vq) comments = [] diff --git a/quiz_machine.py b/quiz_machine.py index 7516aed..13c157e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -62,52 +62,6 @@ def one_batch_masked_inplace_autoregression( input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] -def masked_inplace_autoregression( - model, - batch_size, - input, - ar_mask, - seq_logproba, - logit_transformer=None, - deterministic_synthesis=False, - forbidden_tokens=None, - logit_biases=None, - progress_bar_desc=None, - device=torch.device("cpu"), -): - assert input.size() == ar_mask.size() - - batches = zip( - input.split(batch_size), - ar_mask.split(batch_size), - seq_logproba.split(batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + batch_size - 1) // batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, seq_logproba in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - logit_transformer=logit_transformer, - deterministic_synthesis=deterministic_synthesis, - ) - - model.train(t) - - ###################################################################### @@ -147,6 +101,51 @@ class QuizMachine: ###################################################################### + def autoregression( + model, + input, + ar_mask, + seq_logproba=None, + logit_transformer=None, + progress_bar_desc=None, + ): + assert input.size() == ar_mask.size() + + if seq_logproba is None: + seq_logproba = torch.empty(input.size(0), device=self.device) + + batches = zip( + input.split(self.batch_size), + ar_mask.split(self.batch_size), + seq_logproba.split(self.batch_size), + ) + + if progress_bar_desc is not None: + batches = tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=(input.size(0) + self.batch_size - 1) // self.batch_size, + ) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, ar_mask, seq_logproba in batches: + one_batch_masked_inplace_autoregression( + model=model, + input=input, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + logit_transformer=logit_transformer, + deterministic_synthesis=deterministic_synthesis, + ) + + model.train(t) + + ###################################################################### + def data_input(self, model, split="train"): assert split in {"train", "test"} @@ -194,16 +193,12 @@ class QuizMachine: ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) - seq_logproba = torch.empty(quizzes.size(0), device=self.device) - - masked_inplace_autoregression( + self.autoregression( model=model, - batch_size=self.batch_size, input=result, ar_mask=ar_mask, seq_logproba=seq_logproba, progress_bar_desc="accuracy", - device=self.device, ) correct = (result == quizzes).min(dim=1).values.long() @@ -400,13 +395,11 @@ class QuizMachine: result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) ) - masked_inplace_autoregression( + self.autoregression( model=model, - batch_size=self.batch_size, input=result, ar_mask=ar_mask, seq_logproba=seq_logproba[:, model.id], - device=self.device, ) correct = (c_quizzes == result).long().min(dim=-1).values @@ -420,13 +413,11 @@ class QuizMachine: result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1) ) - masked_inplace_autoregression( + self.autoregression( model=model, - batch_size=self.batch_size, input=result, ar_mask=ar_mask, seq_logproba=seq_logproba[:, model.id], - device=self.device, ) correct *= (reversed_c_quizzes == result).long().min(dim=-1).values @@ -445,6 +436,7 @@ class QuizMachine: model_for_generation, temperature_hot=1.0, temperature_cold=1.0, + to_recycle=None, ): c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")) c_quizzes = c_quizzes.to(self.device) @@ -454,42 +446,41 @@ class QuizMachine: lt_noisy = lambda s, logits: logits / temperature_hot lt_clean = lambda s, logits: logits / temperature_cold - masked_inplace_autoregression( + self.autoregression( model=model_for_generation, - batch_size=self.batch_size, input=c_quizzes, ar_mask=self.make_ar_mask( c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0) ), seq_logproba=seq_logproba, logit_transformer=lt_noisy, - device=self.device, ) - masked_inplace_autoregression( + if to_recycle is not None: + l = c_quizzes.size(1) // 4 + self.logger(f"recycling {to_recycle.size(0)} rejected quizzes") + c_quizzes[: to_recycle.size(0), :l] = to_recycle[:, 3 * l :] + + self.autoregression( model=model_for_generation, - batch_size=self.batch_size, input=c_quizzes, ar_mask=self.make_ar_mask( c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1) ), seq_logproba=seq_logproba, logit_transformer=lt_clean, - device=self.device, ) c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - masked_inplace_autoregression( + self.autoregression( model=model_for_generation, - batch_size=self.batch_size, input=c_quizzes, ar_mask=self.make_ar_mask( c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) ), seq_logproba=seq_logproba, logit_transformer=lt_clean, - device=self.device, ) return c_quizzes.to("cpu") @@ -514,16 +505,75 @@ class QuizMachine: lt_noisy = lambda s, logits: logits / temperature_hot - masked_inplace_autoregression( + self.autoregression( model=model_for_generation, - batch_size=self.batch_size, input=c_quizzes, ar_mask=self.make_ar_mask( c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1) ), seq_logproba=seq_logproba, logit_transformer=lt_noisy, - device=self.device, + ) + + return c_quizzes.to("cpu") + + ###################################################################### + + def generate_c_quizzes_2( + self, + nb, + model_for_generation, + temperature_hot=1.0, + temperature_cold=1.0, + ): + warnings.warn( + "**************************** simple quiz generation", RuntimeWarning + ) + + seq_logproba = torch.zeros(nb, device=self.device) + + lt_noisy = lambda s, logits: logits / temperature_hot + lt_clean = lambda s, logits: logits / temperature_cold + + c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) + c_quizzes = c_quizzes.to(self.device) + + self.autoregression( + model=model_for_generation, + input=c_quizzes, + ar_mask=self.make_ar_mask( + c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 0, 0) + ), + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + ) + + c_quizzes2 = self.problem.create_empty_quizzes(nb, ("B", "f_B", "A", "f_A")) + c_quizzes2 = c_quizzes2.to(self.device) + + self.autoregression( + model=model_for_generation, + input=c_quizzes2, + ar_mask=self.make_ar_mask( + c_quizzes2, + ("B", "f_B", "A", "f_A"), + (1, 0, 0, 0), + ), + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + ) + + l = c_quizzes.size(1) // 4 + c_quizzes[:, 2 * l : 3 * l] = c_quizzes2[:, :l] + + self.autoregression( + model=model_for_generation, + input=c_quizzes, + ar_mask=self.make_ar_mask( + c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, ) return c_quizzes.to("cpu") -- 2.39.5