- input = self.test_w_quizzes[:96]
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
- seq_logproba = torch.empty(input.size(0), device=self.device)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-