+def entropy(target):
+ probas = []
+ for k in range(target.max() + 1):
+ n = (target == k).sum().item()
+ if n > 0: probas.append(n)
+ probas = torch.tensor(probas).float()
+ probas /= probas.sum()
+ return - (probas * probas.log()).sum().item()
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
+train_target = train_set.train_labels.to(device)
+
+test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
+test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
+test_target = test_set.test_labels.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+# Returns a triplet of tensors (a, b, c), where a and b contain each
+# half of the samples, with a[i] and b[i] of same class for any i, and
+# c is a 1d long tensor real classes
+
+def create_image_pairs(train = False):
+ ua, ub = [], []
+
+ if train:
+ input, target = train_input, train_target
+ else:
+ input, target = test_input, test_target
+
+ for i in used_MNIST_classes: