+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--mnist_classes',
+ type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
+ help = 'What MNIST classes to use')
+
+######################################################################
+
+if torch.cuda.is_available():
+ device = torch.device('cuda')
+else:
+ device = torch.device('cpu')
+
+######################################################################
+
+def entropy(target):
+ probas = []
+ for k in range(target.max() + 1):
+ n = (target == k).sum().item()
+ if n > 0: probas.append(n)
+ probas = torch.tensor(probas).float()
+ probas /= probas.sum()
+ return - (probas * probas.log()).sum().item()
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
+train_target = train_set.train_labels.to(device)
+
+test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
+test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
+test_target = test_set.test_labels.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+# Returns a triplet of tensors (a, b, c), where a and b contain each
+# half of the samples, with a[i] and b[i] of same class for any i, and
+# c is a 1d long tensor real classes
+
+def create_image_pairs(train = False):
+ ua, ub = [], []
+
+ if train:
+ input, target = train_input, train_target
+ else:
+ input, target = test_input, test_target
+
+ for i in used_MNIST_classes: