From 09c5eea203d5a2d8b1da84db0a336de151cf1c89 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 2 Jul 2024 12:50:37 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 5f19998..1a20563 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -395,10 +395,15 @@ class QuizzMachine: nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) - ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1 - ar_mask_solve = 1 - ar_mask_prompt - seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) + c_quizzes[:, 0] = self.token_forward + + ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device) + ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1 + ar_mask_second = 1 - ar_mask_first + ar_mask_first[:, 0] = 0 + ar_mask_second[:, 0] = 0 + + seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device) if reverse_cleanup: warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) @@ -416,7 +421,7 @@ class QuizzMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_prompt, + ar_mask=ar_mask_first, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, @@ -431,7 +436,7 @@ class QuizzMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_solve, + ar_mask=ar_mask_second, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=True, @@ -440,11 +445,12 @@ class QuizzMachine: if reverse_cleanup: c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_solve, + ar_mask=ar_mask_second, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=True, @@ -452,11 +458,12 @@ class QuizzMachine: ) c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask_solve, + ar_mask=ar_mask_second, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=True, -- 2.39.5