#!/usr/bin/env python
+import argparse
+
import math, sys, torch, torchvision
from torch import nn
######################################################################
+parser = argparse.ArgumentParser(
+ description = 'An implementation of Mutual Information estimator with a deep model',
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+ type = str, default = 'image_pair',
+ help = 'What data')
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--mnist_classes',
+ type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
+ help = 'What MNIST classes to use')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'))
+
+######################################################################
+
train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
train_input = train_set.train_data.view(-1, 1, 28, 28).float()
train_target = train_set.train_labels
test_input = test_set.test_data.view(-1, 1, 28, 28).float()
test_target = test_set.test_labels
+if torch.cuda.is_available():
+ used_MNIST_classes = used_MNIST_classes.cuda()
+ train_input, train_target = train_input.cuda(), train_target.cuda()
+ test_input, test_target = test_input.cuda(), test_target.cuda()
+
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
-used_MNIST_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9])
-# used_MNIST_classes = torch.tensor([ 0, 9, 7 ])
-# used_MNIST_classes = torch.tensor([ 3, 4, 7, 0 ])
-
######################################################################
# Returns a triplet of tensors (a, b, c), where a and b contain each
# half of the samples, with a[i] and b[i] of same class for any i, and
# c is a 1d long tensor with the count of pairs per class used.
-def create_MNIST_pair_set(train = False):
+def create_image_pairs(train = False):
ua, ub = [], []
if train:
######################################################################
-class Net(nn.Module):
+def create_image_values_pairs(train = False):
+ ua, ub = [], []
+
+ if train:
+ input, target = train_input, train_target
+ else:
+ input, target = test_input, test_target
+
+ m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
+ m[used_MNIST_classes] = 1
+ m = m[target]
+ used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
+
+ input = input[used_indices].contiguous()
+ target = target[used_indices].contiguous()
+
+ a = input
+
+ b = a.new(a.size(0), 2)
+ b[:, 0].uniform_(10)
+ b[:, 1].uniform_(0.5)
+ b[:, 1] += b[:, 0] + target.float()
+
+ c = torch.tensor([(target == k).sum().item() for k in used_MNIST_classes])
+
+ return a, b, c
+
+######################################################################
+
+class NetImagePair(nn.Module):
def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(2, 32, kernel_size = 5)
- self.conv2 = nn.Conv2d(32, 64, kernel_size = 5)
- self.fc1 = nn.Linear(256, 200)
- self.fc2 = nn.Linear(200, 1)
+ super(NetImagePair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
def forward(self, a, b):
- # Make the two images a single two-channel image
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
x = torch.cat((a, b), 1)
- x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3))
- x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2))
- x = x.view(x.size(0), -1)
- x = F.relu(self.fc1(x))
- x = self.fc2(x)
- return x
+ return self.fully_connected(x)
######################################################################
-nb_epochs, batch_size = 50, 100
+class NetImageValuesPair(nn.Module):
+ def __init__(self):
+ super(NetImageValuesPair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Linear(2, 32), nn.ReLU(),
+ nn.Linear(32, 32), nn.ReLU(),
+ nn.Linear(32, 128), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
-model = Net()
+ def forward(self, a, b):
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
+ x = torch.cat((a, b), 1)
+ return self.fully_connected(x)
+
+######################################################################
+
+if args.data == 'image_pair':
+ create_pairs = create_image_pairs
+ model = NetImagePair()
+elif args.data == 'image_values_pair':
+ create_pairs = create_image_values_pairs
+ model = NetImageValuesPair()
+else:
+ raise Exception('Unknown data ' + args.data)
+
+######################################################################
+
+nb_epochs, batch_size = 50, 100
print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
if torch.cuda.is_available():
model.cuda()
- train_input, train_target = train_input.cuda(), train_target.cuda()
- test_input, test_target = test_input.cuda(), test_target.cuda()
for e in range(nb_epochs):
- input_a, input_b, count = create_MNIST_pair_set(train = True)
+ input_a, input_b, count = create_pairs(train = True)
# The information bound is the entropy of the class distribution
class_proba = count.float()
######################################################################
-input_a, input_b, count = create_MNIST_pair_set(train = False)
+input_a, input_b, count = create_pairs(train = False)
for e in range(nb_epochs):
class_proba = count.float()