- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
-
- sum_logits = masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
-
- # Should not be necessary though, the autoregression is done
- # in eval mode
- sum_logits = sum_logits.detach()