X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=5807b660c1fb2a3e291d41661b01d94e93a52d1f;hb=a8e608a50b84583ad624cdf69d7b34699557235b;hp=5f199988b3a92d433c00f8b57f5ddf1863ec3019;hpb=3b41e2797fc340fd11cb35015b57c3cae1e8447b;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 5f19998..5807b66 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -395,72 +395,63 @@ class QuizzMachine: nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) - ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1 - ar_mask_solve = 1 - ar_mask_prompt - seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) + ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) + ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 + ar_mask_second = 1 - ar_mask_first + ar_mask_first[:, 0] = 0 + ar_mask_second[:, 0] = 0 + + seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) if reverse_cleanup: - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) temperature = 10.0 else: temperature = 1.0 - # warnings.warn("noise injection", RuntimeWarning) - # noise_std = torch.rand(1).item() - # self.logger(f"{noise_std=}") + # First, we generate the answer at high temperature - # mygpt.set_noise_injection(model_for_generation, noise_std) + c_quizzes[:, 0] = self.token_backward masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_prompt, + ar_mask=ar_mask_first, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, device=self.device, ) - # mygpt.set_noise_injection(model_for_generation, 0.0) - ave_seq_logproba = seq_logproba.mean() + # Then, we generate the prompt deterministically + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_solve, + ar_mask=ar_mask_second, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=True, device=self.device, ) - if reverse_cleanup: - c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_solve, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + # Then we return the quizz, and re-generate the response, now + # deterministically - c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_solve, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + c_quizzes = self.reverse_time(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_second, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) return c_quizzes, seq_logproba.mean()