X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=tasks.py;h=64fe96763f16033fe1241ef91a0b677d13b96f44;hb=fc1de19bf86b2cfd09264dfc6fbda1937248a40a;hp=cdf8f9e48e583341ab68cdb30e90f5220095dd6e;hpb=17c63771f2ca82ce39d8406e377ace2015fe69fc;p=culture.git diff --git a/tasks.py b/tasks.py index cdf8f9e..64fe967 100755 --- a/tasks.py +++ b/tasks.py @@ -154,6 +154,9 @@ class World(Task): self.nb_batch_samples_world = input.size(0) self.nb_batch_samples_quizzes = 0 + # Shuffle + input = input[torch.randperm(input.size(0))] + if desc is None: desc = f"epoch-{split}" for batch in tqdm.tqdm( @@ -274,8 +277,15 @@ class World(Task): device=self.device, ) + # Should not be necessary though, the autoregression is done + # in eval mode + sum_logits = sum_logits.detach() + average_logits = sum_logits / quizzes.numel() + # It's a bit brutal to do it twice, we should probably have a + # moving average and apply it right away + if desired_average_logits is not None: temperature = average_logits / desired_average_logits masked_inplace_autoregression(