X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=889d4a9602029153c40937af55102b51be58a972;hb=92c05a4b88f5b0de6a84f4319bb18e2687a0fc2f;hp=a97ec2e4c8298dcd764480e114fee099979a9d12;hpb=4540ede418ea744e50e0ff0b3a90785015da962b;p=picoclvr.git diff --git a/tasks.py b/tasks.py index a97ec2e..889d4a9 100755 --- a/tasks.py +++ b/tasks.py @@ -34,7 +34,7 @@ def masked_inplace_autoregression( batches, dynamic_ncols=True, desc=progress_bar_desc, - # total=input.size(0) // batch_size, + total=(input.size(0) + batch_size - 1) // batch_size, ) with torch.autograd.no_grad():