X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=45bddb762721c468d8990de9cc88044c5cbbd635;hb=f23843d33a4fa5a38f5034deab8f473793732ee3;hp=0c2ff24dd8907845418b13c5867ba41469c83529;hpb=3c97745cdf9ae30a87903e3039e38c868e136d6e;p=picoclvr.git diff --git a/main.py b/main.py index 0c2ff24..45bddb7 100755 --- a/main.py +++ b/main.py @@ -187,6 +187,8 @@ def masked_inplace_autoregression( progress_bar_desc="autoregression", device=torch.device("cpu"), ): + # p = logits.softmax(1) + # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2) batches = zip(input.split(batch_size), ar_mask.split(batch_size)) if progress_bar_desc is not None: tqdm.tqdm(