progress_bar_desc="autoregression",
device=torch.device("cpu"),
):
progress_bar_desc="autoregression",
device=torch.device("cpu"),
):
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
tqdm.tqdm(
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
tqdm.tqdm(