-# Returns a pair of tensors (x, c), where x is a Nx2x28x28 containing
-# pairs of images of same classes (one per channel), and p is a 1d
-# long tensor with the count of pairs per class used
+train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+train_input = train_set.train_data.view(-1, 1, 28, 28).float()
+train_target = train_set.train_labels
+
+test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
+test_input = test_set.test_data.view(-1, 1, 28, 28).float()
+test_target = test_set.test_labels
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+used_MNIST_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9])
+# used_MNIST_classes = torch.tensor([ 0, 9, 7 ])
+# used_MNIST_classes = torch.tensor([ 3, 4, 7, 0 ])
+
+######################################################################
+
+# Returns a triplet of tensors (a, b, c), where a and b contain each
+# half of the samples, with a[i] and b[i] of same class for any i, and
+# c is a 1d long tensor with the count of pairs per class used.
+
+def create_MNIST_pair_set(train = False):
+ ua, ub = [], []