X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=0d6d8f57cba918235f951333a64cb1a4c44133d2;hb=d283cd3d46a6323fec4c6a0970ac71e553e4a486;hp=5f199988b3a92d433c00f8b57f5ddf1863ec3019;hpb=3b41e2797fc340fd11cb35015b57c3cae1e8447b;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 5f19998..0d6d8f5 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -333,7 +333,7 @@ class QuizzMachine: ) def compute_correctness( - self, c_quizzes, models_for_validation, both_directions=True + self, c_quizzes, models_for_validation, both_directions=False ): reversed_c_quizzes = self.reverse_time(c_quizzes) @@ -390,77 +390,65 @@ class QuizzMachine: ############################################################### - def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False): + def generate_quizzes(self, nb, model_for_generation): c_quizzes = torch.empty( nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) - ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1 - ar_mask_solve = 1 - ar_mask_prompt - seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) + ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) + ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 + ar_mask_second = 1 - ar_mask_first + ar_mask_first[:, 0] = 0 + ar_mask_second[:, 0] = 0 - if reverse_cleanup: - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10.0 - else: - temperature = 1.0 + seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) + + temperature = 10.0 - # warnings.warn("noise injection", RuntimeWarning) - # noise_std = torch.rand(1).item() - # self.logger(f"{noise_std=}") + # First, we generate the answer at high temperature - # mygpt.set_noise_injection(model_for_generation, noise_std) + c_quizzes[:, 0] = self.token_backward masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_prompt, + ar_mask=ar_mask_first, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, device=self.device, ) - # mygpt.set_noise_injection(model_for_generation, 0.0) - ave_seq_logproba = seq_logproba.mean() + # Then, we generate the prompt deterministically + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_solve, + ar_mask=ar_mask_second, seq_logproba=seq_logproba, - temperature=temperature, + temperature=1.0, deterministic_synthesis=True, device=self.device, ) - if reverse_cleanup: - c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_solve, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + # Then we return the quizz, and re-generate the response, now + # deterministically - c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_solve, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + c_quizzes = self.reverse_time(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_second, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) return c_quizzes, seq_logproba.mean()