progress_bar_desc="autoregression",
device=torch.device("cpu"),
):
+ # p = logits.softmax(1)
+ # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
tqdm.tqdm(