X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=0827a446d0606509f0b0b8fa23f90bfdd0e2ab63;hb=cd5e4647e105a10012d687169d49bec0343e274f;hp=a97ec2e4c8298dcd764480e114fee099979a9d12;hpb=4540ede418ea744e50e0ff0b3a90785015da962b;p=picoclvr.git diff --git a/tasks.py b/tasks.py index a97ec2e..0827a44 100755 --- a/tasks.py +++ b/tasks.py @@ -34,7 +34,7 @@ def masked_inplace_autoregression( batches, dynamic_ncols=True, desc=progress_bar_desc, - # total=input.size(0) // batch_size, + total=(input.size(0) + batch_size - 1) // batch_size, ) with torch.autograd.no_grad(): @@ -1070,6 +1070,7 @@ class RPL(Task): train_sequences = [ rpl.generate( nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, max_input=max_input, prog_len=prog_len, nb_runs=nb_runs, @@ -1080,6 +1081,7 @@ class RPL(Task): test_sequences = [ rpl.generate( nb_starting_values=nb_starting_values, + nb_result_values_max=4 * nb_starting_values, max_input=max_input, prog_len=prog_len, nb_runs=nb_runs,