- for input, ar_mask in tqdm.tqdm(
- zip(input.split(batch_size), ar_mask.split(batch_size)),
- dynamic_ncols=True,
- desc="autoregression",
- total=input.size(0) // batch_size,
- ):
+ batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+ if progress_bar_desc is not None:
+ tqdm.tqdm(
+ batches,
+ dynamic_ncols=True,
+ desc=progress_bar_desc,
+ total=input.size(0) // batch_size,
+ )
+ for input, ar_mask in batches: