X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=5807b660c1fb2a3e291d41661b01d94e93a52d1f;hb=a8e608a50b84583ad624cdf69d7b34699557235b;hp=1a205633ca0b72961cf22bfff0318288d84a78ae;hpb=09c5eea203d5a2d8b1da84db0a336de151cf1c89;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 1a20563..5807b66 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -395,8 +395,6 @@ class QuizzMachine: nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - c_quizzes[:, 0] = self.token_forward - ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 ar_mask_second = 1 - ar_mask_first @@ -406,16 +404,13 @@ class QuizzMachine: seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) if reverse_cleanup: - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) temperature = 10.0 else: temperature = 1.0 - # warnings.warn("noise injection", RuntimeWarning) - # noise_std = torch.rand(1).item() - # self.logger(f"{noise_std=}") + # First, we generate the answer at high temperature - # mygpt.set_noise_injection(model_for_generation, noise_std) + c_quizzes[:, 0] = self.token_backward masked_inplace_autoregression( model=model_for_generation, @@ -428,10 +423,10 @@ class QuizzMachine: device=self.device, ) - # mygpt.set_noise_injection(model_for_generation, 0.0) - ave_seq_logproba = seq_logproba.mean() + # Then, we generate the prompt deterministically + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, @@ -443,31 +438,20 @@ class QuizzMachine: device=self.device, ) - if reverse_cleanup: - c_quizzes = self.reverse_time(c_quizzes) + # Then we return the quizz, and re-generate the response, now + # deterministically - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) - - c_quizzes = self.reverse_time(c_quizzes) + c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_second, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) return c_quizzes, seq_logproba.mean()