- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis=False,
- progress_bar_desc=None,
- device=self.device,
- )
+ dispatch = torch.randint(len(other_models), (result.size(0),))
+
+ for n, m in enumerate(other_models):
+ masked_inplace_autoregression(
+ m,
+ self.batch_size,
+ result[dispatch == n],
+ ar_mask[dispatch == n],
+ deterministic_synthesis=False,
+ progress_bar_desc=None,
+ device=self.device,
+ )