X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=d21e2648466a606b3067fc680feb7305a6b95781;hb=8ea809c43242d3a2e063692105919a86c3f6fe6b;hp=e5d3a7e66a77c769d970da8c0268f7fee307b7b9;hpb=1eef58fd084437bbcd2041b946b468615e203dd8;p=picoclvr.git diff --git a/tasks.py b/tasks.py index e5d3a7e..d21e264 100755 --- a/tasks.py +++ b/tasks.py @@ -71,7 +71,7 @@ class Task: class TaskFromFile(Task): - def tensorize(self, pairs): + def tensorize(self, pairs, shuffle): len_max = max([len(x[0]) for x in pairs]) input = torch.cat( @@ -98,6 +98,12 @@ class TaskFromFile(Task): 0, ).to("cpu") + if shuffle: + print("SHUFFLING!") + i = torch.randperm(input.size(0)) + input = input[i].contiguous() + pred_mask = pred_mask[i].contiguous() + return input, pred_mask # trim all the tensors in the tuple z to remove as much token from @@ -122,6 +128,7 @@ class TaskFromFile(Task): nb_train_samples, nb_test_samples, batch_size, + shuffle=False, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -156,8 +163,12 @@ class TaskFromFile(Task): self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) - self.train_input, self.train_pred_masks = self.tensorize(train_pairs) - self.test_input, self.test_pred_masks = self.tensorize(test_pairs) + self.train_input, self.train_pred_masks = self.tensorize( + train_pairs, shuffle=shuffle + ) + self.test_input, self.test_pred_masks = self.tensorize( + test_pairs, shuffle=shuffle + ) def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"}