for input in task.batches(split="train"):
assert input.dim() == 2 and input.dtype == torch.int64
for x in input:
train_examples[x.sum().item()] = x
for input in task.batches(split="train"):
assert input.dim() == 2 and input.dtype == torch.int64
for x in input:
train_examples[x.sum().item()] = x