+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+used_MNIST_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9])
+# used_MNIST_classes = torch.tensor([ 0, 9, 7 ])
+# used_MNIST_classes = torch.tensor([ 3, 4, 7, 0 ])
+
+######################################################################
+
+# Returns a triplet of tensors (a, b, c), where a and b contain each
+# half of the samples, with a[i] and b[i] of same class for any i, and
+# c is a 1d long tensor with the count of pairs per class used.
+
+def create_MNIST_pair_set(train = False):
+ ua, ub = [], []
+
+ if train:
+ input, target = train_input, train_target
+ else:
+ input, target = test_input, test_target
+
+ for i in used_MNIST_classes: