Update.
[culture.git] / tasks.py
index cdf8f9e..64fe967 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -154,6 +154,9 @@ class World(Task):
             self.nb_batch_samples_world = input.size(0)
             self.nb_batch_samples_quizzes = 0
 
             self.nb_batch_samples_world = input.size(0)
             self.nb_batch_samples_quizzes = 0
 
+        # Shuffle
+        input = input[torch.randperm(input.size(0))]
+
         if desc is None:
             desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
         if desc is None:
             desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
@@ -274,8 +277,15 @@ class World(Task):
             device=self.device,
         )
 
             device=self.device,
         )
 
+        # Should not be necessary though, the autoregression is done
+        # in eval mode
+        sum_logits = sum_logits.detach()
+
         average_logits = sum_logits / quizzes.numel()
 
         average_logits = sum_logits / quizzes.numel()
 
+        # It's a bit brutal to do it twice, we should probably have a
+        # moving average and apply it right away
+
         if desired_average_logits is not None:
             temperature = average_logits / desired_average_logits
             masked_inplace_autoregression(
         if desired_average_logits is not None:
             temperature = average_logits / desired_average_logits
             masked_inplace_autoregression(