From a8e608a50b84583ad624cdf69d7b34699557235b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 2 Jul 2024 13:07:46 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 50 ++++++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 1a20563..5807b66 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -395,8 +395,6 @@ class QuizzMachine: nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - c_quizzes[:, 0] = self.token_forward - ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 ar_mask_second = 1 - ar_mask_first @@ -406,16 +404,13 @@ class QuizzMachine: seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) if reverse_cleanup: - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) temperature = 10.0 else: temperature = 1.0 - # warnings.warn("noise injection", RuntimeWarning) - # noise_std = torch.rand(1).item() - # self.logger(f"{noise_std=}") + # First, we generate the answer at high temperature - # mygpt.set_noise_injection(model_for_generation, noise_std) + c_quizzes[:, 0] = self.token_backward masked_inplace_autoregression( model=model_for_generation, @@ -428,10 +423,10 @@ class QuizzMachine: device=self.device, ) - # mygpt.set_noise_injection(model_for_generation, 0.0) - ave_seq_logproba = seq_logproba.mean() + # Then, we generate the prompt deterministically + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, @@ -443,31 +438,20 @@ class QuizzMachine: device=self.device, ) - if reverse_cleanup: - c_quizzes = self.reverse_time(c_quizzes) + # Then we return the quizz, and re-generate the response, now + # deterministically - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) - - c_quizzes = self.reverse_time(c_quizzes) + c_quizzes = self.reverse_time(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_second, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_second, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) return c_quizzes, seq_logproba.mean() -- 2.39.5