-train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
-train_input = train_set.train_data.view(-1, 1, 28, 28).float()
-train_target = train_set.train_labels
-
-test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
-test_input = test_set.test_data.view(-1, 1, 28, 28).float()
-test_target = test_set.test_labels
-
-mu, std = train_input.mean(), train_input.std()
-train_input.sub_(mu).div_(std)
-test_input.sub_(mu).div_(std)
-
-######################################################################
-
-# The information bound is the log of the number of classes in there
-
-# used_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9])
-used_classes = torch.tensor([ 3, 4, 7, 0 ])
-