X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=64fe96763f16033fe1241ef91a0b677d13b96f44;hb=fc1de19bf86b2cfd09264dfc6fbda1937248a40a;hp=5edb472024342565801fee1f7de48f55bb17756a;hpb=702e672dcf9ebcfad11ae4034e64117f2c67ead5;p=culture.git diff --git a/tasks.py b/tasks.py index 5edb472..64fe967 100755 --- a/tasks.py +++ b/tasks.py @@ -154,6 +154,9 @@ class World(Task): self.nb_batch_samples_world = input.size(0) self.nb_batch_samples_quizzes = 0 + # Shuffle + input = input[torch.randperm(input.size(0))] + if desc is None: desc = f"epoch-{split}" for batch in tqdm.tqdm(