X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quizz_machine.py;h=470b095ce38ec23d6a034ee3a338d87b7b1e9b52;hb=08d4ba04f038318080fc2815d85843c4873c896f;hp=1a205633ca0b72961cf22bfff0318288d84a78ae;hpb=09c5eea203d5a2d8b1da84db0a336de151cf1c89;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 1a20563..470b095 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -29,8 +29,6 @@ def one_batch_masked_inplace_autoregression( seq_logproba, temperature=1.0, deterministic_synthesis=False, - forbidden_tokens=None, - forced_biases=None, ): to_generate = (ar_mask.sum(0) > 0).nonzero() @@ -45,12 +43,6 @@ def one_batch_masked_inplace_autoregression( logits = (logits / temperature).log_softmax(dim=-1) - if forbidden_tokens is not None: - logits = logits.masked_fill(forbidden_tokens, float("-inf")) - - if forced_biases is not None: - logits = logits + forced_biases[None, :] - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -104,8 +96,6 @@ def masked_inplace_autoregression( seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=deterministic_synthesis, - forbidden_tokens=forbidden_tokens, - forced_biases=logit_biases, ) model.train(t) @@ -170,7 +160,6 @@ class QuizzMachine: ) def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False): - print(f"DEBUG {quizzes.size()=}") l = (quizzes.size(1) - 1) // 2 forward = (quizzes[:, 0] == self.token_forward).long() backward = (quizzes[:, 0] == self.token_backward).long() @@ -333,12 +322,16 @@ class QuizzMachine: ) def compute_correctness( - self, c_quizzes, models_for_validation, both_directions=True + self, c_quizzes, models_for_validation, both_directions=False ): reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) - seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + seq_logproba = torch.zeros( + c_quizzes.size(0), + max([m.id for m in models_for_validation]) + 1, + device=self.device, + ) # Check how many of models can solve the quizzes in both directions @@ -347,12 +340,14 @@ class QuizzMachine: for model in models_for_validation: result = c_quizzes.clone() + seq_logproba[...] = 0.0 + masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving c_quizzes", @@ -369,7 +364,7 @@ class QuizzMachine: batch_size=self.batch_size, input=reversed_result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logproba=seq_logproba[:, model.id], temperature=1.0, deterministic_synthesis=True, # progress_bar_desc="solving reversed c_quizzes", @@ -386,36 +381,28 @@ class QuizzMachine: nb_correct += correct - return nb_correct + return nb_correct, seq_logproba ############################################################### - def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False): + def generate_quizzes(self, nb, model_for_generation): c_quizzes = torch.empty( nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - c_quizzes[:, 0] = self.token_forward - ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 ar_mask_second = 1 - ar_mask_first ar_mask_first[:, 0] = 0 ar_mask_second[:, 0] = 0 - seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) + seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device) - if reverse_cleanup: - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10.0 - else: - temperature = 1.0 + temperature = 10.0 - # warnings.warn("noise injection", RuntimeWarning) - # noise_std = torch.rand(1).item() - # self.logger(f"{noise_std=}") + # First, we generate the answer at high temperature - # mygpt.set_noise_injection(model_for_generation, noise_std) + c_quizzes[:, 0] = self.token_backward masked_inplace_autoregression( model=model_for_generation, @@ -428,9 +415,7 @@ class QuizzMachine: device=self.device, ) - # mygpt.set_noise_injection(model_for_generation, 0.0) - - ave_seq_logproba = seq_logproba.mean() + # Then, we generate the prompt deterministically masked_inplace_autoregression( model=model_for_generation, @@ -438,36 +423,25 @@ class QuizzMachine: input=c_quizzes, ar_mask=ar_mask_second, seq_logproba=seq_logproba, - temperature=temperature, + temperature=1.0, deterministic_synthesis=True, device=self.device, ) - if reverse_cleanup: - c_quizzes = self.reverse_time(c_quizzes) + # Then we return the quizz, and re-generate the response, now + # deterministically - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + c_quizzes = self.reverse_time(c_quizzes) - c_quizzes = self.reverse_time(c_quizzes) - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_second, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) - return c_quizzes, seq_logproba.mean() + return c_quizzes