nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
- ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
- ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
- ar_mask_solve = 1 - ar_mask_prompt
- seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
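+ # Position 0 carries the direction marker (token_forward); set it before any generation pass.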
+ c_quizzes[:, 0] = self.token_forward
+
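+ # ar_mask_first selects the prompt half of each sequence, ar_mask_second the solution half.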
+ ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
+ ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
+ ar_mask_second = 1 - ar_mask_first
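+ # Exclude position 0 from both masks so the direction token is never resampled.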
+ ar_mask_first[:, 0] = 0
+ ar_mask_second[:, 0] = 0
+
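+ # Per-sequence log-probability buffer, presumably filled in place by masked_inplace_autoregression.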
+ seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
if reverse_cleanup:
warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=ar_mask_prompt,
+ ar_mask=ar_mask_first,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=ar_mask_solve,
+ ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=True,
if reverse_cleanup:
c_quizzes = self.reverse_time(c_quizzes)
+
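+ # Cleanup pass: with the quizzes time-reversed, regenerate the second half deterministically.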
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=ar_mask_solve,
+ ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=True,
)
c_quizzes = self.reverse_time(c_quizzes)
+
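+ # Reversed back to the original order: solve the second half once more deterministically.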
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=ar_mask_solve,
+ ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=True,