X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=b967465de6c23a94b9ee1af8a672337167903ac8;hb=8a548630c88957264306db4354e880414b0fa8ef;hp=cdf8f9e48e583341ab68cdb30e90f5220095dd6e;hpb=17c63771f2ca82ce39d8406e377ace2015fe69fc;p=culture.git diff --git a/tasks.py b/tasks.py index cdf8f9e..b967465 100755 --- a/tasks.py +++ b/tasks.py @@ -276,6 +276,9 @@ class World(Task): average_logits = sum_logits / quizzes.numel() + # It's a bit brutal to do it twice, we should probably have a + # moving average and apply it right away + if desired_average_logits is not None: temperature = average_logits / desired_average_logits masked_inplace_autoregression(