OCD update.
[pytorch.git] / mine_mnist.py
index 6f65136..7845d81 100755 (executable)
 #!/usr/bin/env python
 
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: ~/conda/bin/python
-# @XREMOTE_PRE: ln -s ~/data/pytorch ./data
-# @XREMOTE_PRE: killall -q -9 python || true
+import argparse, math, sys
 
-import math, sys, torch, torchvision
+import torch, torchvision
 
 from torch import nn
-from torch.nn import functional as F
 
 ######################################################################
 
-# Returns a pair of tensors (a, b, c), where a and b are tensors
-# containing each half of the samples, with a[i] and b[i] of same
-# class for any i, and c is a 1d long tensor with the count of pairs
-# per class used.
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description = 'An implementation of Mutual Information estimator with a deep model',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+                    type = str, default = 'image_pair',
+                    help = 'What data')
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--mnist_classes',
+                    type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
+                    help = 'What MNIST classes to use')
+
+######################################################################
+
+def entropy(target):
+    probas = []
+    for k in range(target.max() + 1):
+        n = (target == k).sum().item()
+        if n > 0: probas.append(n)
+    probas = torch.tensor(probas).float()
+    probas /= probas.sum()
+    return - (probas * probas.log()).sum().item()
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
 
-def create_pair_set(used_classes, input, target):
+######################################################################
+
+train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
+train_target = train_set.train_labels.to(device)
+
+test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
+test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
+test_target = test_set.test_labels.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+# Returns a triplet of tensors (a, b, c), where a and b contain each
+# half of the samples, with a[i] and b[i] of same class for any i, and
+# c is a 1d long tensor real classes
+
+def create_image_pairs(train = False):
     ua, ub = [], []
 
-    for i in used_classes:
+    if train:
+        input, target = train_input, train_target
+    else:
+        input, target = test_input, test_target
+
+    for i in used_MNIST_classes:
         used_indices = torch.arange(input.size(0), device = target.device)\
                             .masked_select(target == i.item())
         x = input[used_indices]
         x = x[torch.randperm(x.size(0))]
-        ua.append(x.narrow(0, 0, x.size(0)//2))
-        ub.append(x.narrow(0, x.size(0)//2, x.size(0)//2))
+        hs = x.size(0)//2
+        ua.append(x.narrow(0, 0, hs))
+        ub.append(x.narrow(0, hs, hs))
+        uc.append(target[used_indices])
 
     a = torch.cat(ua, 0)
     b = torch.cat(ub, 0)
+    c = torch.cat(uc, 0)
     perm = torch.randperm(a.size(0))
     a = a[perm].contiguous()
     b = b[perm].contiguous()
-    c = torch.tensor([x.size(0) for x in ua])
 
     return a, b, c
 
 ######################################################################
 
-class Net(nn.Module):
+# Returns a triplet a, b, c where a are the standard MNIST images, c
+# the classes, and b is a Nx2 tensor, eith for every n:
+#
+#   b[n, 0] ~ Uniform(0, 10)
+#   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
+
+def create_image_values_pairs(train = False):
+    ua, ub = [], []
+
+    if train:
+        input, target = train_input, train_target
+    else:
+        input, target = test_input, test_target
+
+    m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
+    m[used_MNIST_classes] = 1
+    m = m[target]
+    used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
+
+    input = input[used_indices].contiguous()
+    target = target[used_indices].contiguous()
+
+    a = input
+    c = target
+
+    b = a.new(a.size(0), 2)
+    b[:, 0].uniform_(10)
+    b[:, 1].uniform_(0.5)
+    b[:, 1] += b[:, 0] + target.float()
+
+    return a, b, c
+
+######################################################################
+
+class NetForImagePair(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(2, 32, kernel_size = 5)
-        self.conv2 = nn.Conv2d(32, 64, kernel_size = 5)
-        self.fc1 = nn.Linear(256, 200)
-        self.fc2 = nn.Linear(200, 1)
+        super(NetForImagePair, self).__init__()
+        self.features_a = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size = 5),
+            nn.MaxPool2d(3), nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size = 5),
+            nn.MaxPool2d(2), nn.ReLU(),
+        )
+
+        self.features_b = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size = 5),
+            nn.MaxPool2d(3), nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size = 5),
+            nn.MaxPool2d(2), nn.ReLU(),
+        )
+
+        self.fully_connected = nn.Sequential(
+            nn.Linear(256, 200),
+            nn.ReLU(),
+            nn.Linear(200, 1)
+        )
 
     def forward(self, a, b):
+        a = self.features_a(a).view(a.size(0), -1)
+        b = self.features_b(b).view(b.size(0), -1)
         x = torch.cat((a, b), 1)
-        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3))
-        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2))
-        x = x.view(x.size(0), -1)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
+        return self.fully_connected(x)
 
 ######################################################################
 
-train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
-train_input  = train_set.train_data.view(-1, 1, 28, 28).float()
-train_target = train_set.train_labels
-
-test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
-test_input = test_set.test_data.view(-1, 1, 28, 28).float()
-test_target = test_set.test_labels
+class NetForImageValuesPair(nn.Module):
+    def __init__(self):
+        super(NetForImageValuesPair, self).__init__()
+        self.features_a = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size = 5),
+            nn.MaxPool2d(3), nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size = 5),
+            nn.MaxPool2d(2), nn.ReLU(),
+        )
+
+        self.features_b = nn.Sequential(
+            nn.Linear(2, 32), nn.ReLU(),
+            nn.Linear(32, 32), nn.ReLU(),
+            nn.Linear(32, 128), nn.ReLU(),
+        )
+
+        self.fully_connected = nn.Sequential(
+            nn.Linear(256, 200),
+            nn.ReLU(),
+            nn.Linear(200, 1)
+        )
 
-mu, std = train_input.mean(), train_input.std()
-train_input.sub_(mu).div_(std)
-test_input.sub_(mu).div_(std)
+    def forward(self, a, b):
+        a = self.features_a(a).view(a.size(0), -1)
+        b = self.features_b(b).view(b.size(0), -1)
+        x = torch.cat((a, b), 1)
+        return self.fully_connected(x)
 
 ######################################################################
 
-# The information bound is the log of the number of classes in there
+if args.data == 'image_pair':
+    create_pairs = create_image_pairs
+    model = NetForImagePair()
+elif args.data == 'image_values_pair':
+    create_pairs = create_image_values_pairs
+    model = NetForImageValuesPair()
+else:
+    raise Exception('Unknown data ' + args.data)
 
-# used_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9])
-used_classes = torch.tensor([ 3, 4, 7, 0 ])
+######################################################################
 
 nb_epochs, batch_size = 50, 100
 
-model = Net()
+print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
+
 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
 
-if torch.cuda.is_available():
-    model.cuda()
-    train_input, train_target = train_input.cuda(), train_target.cuda()
-    test_input, test_target = test_input.cuda(), test_target.cuda()
+model.to(device)
 
 for e in range(nb_epochs):
 
-    input_a, input_b, count = create_pair_set(used_classes, train_input, train_target)
-
-    class_proba = count.float()
-    class_proba /= class_proba.sum()
-    class_entropy = - (class_proba.log() * class_proba).sum().item()
+    input_a, input_b, classes = create_pairs(train = True)
 
     input_br = input_b[torch.randperm(input_b.size(0))]
 
@@ -111,19 +235,15 @@ for e in range(nb_epochs):
 
     acc_mi /= (input_a.size(0) // batch_size)
 
-    print('%d %.04f %.04f'%(e, acc_mi / math.log(2), class_entropy / math.log(2)))
+    print('%d %.04f %.04f' % (e, acc_mi / math.log(2), entropy(classes) / math.log(2)))
 
     sys.stdout.flush()
 
 ######################################################################
 
-input_a, input_b, count = create_pair_set(used_classes, test_input, test_target)
+input_a, input_b, classes = create_pairs(train = False)
 
 for e in range(nb_epochs):
-    class_proba = count.float()
-    class_proba /= class_proba.sum()
-    class_entropy = - (class_proba.log() * class_proba).sum().item()
-
     input_br = input_b[torch.randperm(input_b.size(0))]
 
     acc_mi = 0.0
@@ -131,11 +251,11 @@ for e in range(nb_epochs):
     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
                                           input_b.split(batch_size),
                                           input_br.split(batch_size)):
-        loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
-        acc_mi -= loss.item()
+        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+        acc_mi += mi.item()
 
     acc_mi /= (input_a.size(0) // batch_size)
 
-print('test %.04f %.04f'%(acc_mi / math.log(2), class_entropy / math.log(2)))
+print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
 
 ######################################################################