nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
+ shuffle=True,
device=device,
)
args.max_percents_of_test_in_train = 0
class TaskFromFile(Task):
- def tensorize(self, pairs):
+ def tensorize(self, pairs, shuffle):
len_max = max([len(x[0]) for x in pairs])
input = torch.cat(
0,
).to("cpu")
+ if shuffle:
+ print("SHUFFLING!")
+ i = torch.randperm(input.size(0))
+ input = input[i].contiguous()
+ pred_mask = pred_mask[i].contiguous()
+
return input, pred_mask
# trim all the tensors in the tuple z to remove as much token from
nb_train_samples,
nb_test_samples,
batch_size,
+ shuffle=False,
device=torch.device("cpu"),
):
self.batch_size = batch_size
self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
self.id2char = dict([(n, c) for c, n in self.char2id.items()])
- self.train_input, self.train_pred_masks = self.tensorize(train_pairs)
- self.test_input, self.test_pred_masks = self.tensorize(test_pairs)
+ self.train_input, self.train_pred_masks = self.tensorize(
+ train_pairs, shuffle=shuffle
+ )
+ self.test_input, self.test_pred_masks = self.tensorize(
+ test_pairs, shuffle=shuffle
+ )
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}