X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quizz_machine.py;h=0d6d8f57cba918235f951333a64cb1a4c44133d2;hb=d283cd3d46a6323fec4c6a0970ac71e553e4a486;hp=1a205633ca0b72961cf22bfff0318288d84a78ae;hpb=09c5eea203d5a2d8b1da84db0a336de151cf1c89;p=culture.git

diff --git a/quizz_machine.py b/quizz_machine.py
index 1a20563..0d6d8f5 100755
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -333,7 +333,7 @@ class QuizzMachine:
         )
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=True
+        self, c_quizzes, models_for_validation, both_directions=False
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -390,13 +390,11 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
+    def generate_quizzes(self, nb, model_for_generation):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        c_quizzes[:, 0] = self.token_forward
-
         ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
         ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
         ar_mask_second = 1 - ar_mask_first
@@ -405,17 +403,11 @@ class QuizzMachine:
 
         seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
 
-        if reverse_cleanup:
-            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-            temperature = 10.0
-        else:
-            temperature = 1.0
+        temperature = 10.0
 
-        # warnings.warn("noise injection", RuntimeWarning)
-        # noise_std = torch.rand(1).item()
-        # self.logger(f"{noise_std=}")
+        # First, we generate the answer at high temperature
 
-        # mygpt.set_noise_injection(model_for_generation, noise_std)
+        c_quizzes[:, 0] = self.token_backward
 
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
@@ -428,46 +420,35 @@ class QuizzMachine:
             device=self.device,
         )
 
-        # mygpt.set_noise_injection(model_for_generation, 0.0)
-
         ave_seq_logproba = seq_logproba.mean()
 
+        # Then, we generate the prompt deterministically
+
         masked_inplace_autoregression(
             model=model_for_generation,
             batch_size=self.batch_size,
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=temperature,
+            temperature=1.0,
             deterministic_synthesis=True,
             device=self.device,
         )
 
-        if reverse_cleanup:
-            c_quizzes = self.reverse_time(c_quizzes)
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_second,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        # Then we return the quizz, and re-generate the response, now
+        # deterministically
 
-            c_quizzes = self.reverse_time(c_quizzes)
+        c_quizzes = self.reverse_time(c_quizzes)
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_second,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_second,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            device=self.device,
+        )
 
         return c_quizzes, seq_logproba.mean()