- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
-
- sum_logits = masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )