Update.
[pytorch.git] / mine_mnist.py
index 412c624..c22d7fe 100755 (executable)
@@ -1,11 +1,20 @@
 #!/usr/bin/env python
 
-import argparse
+import argparse, math, sys
+from copy import deepcopy
 
-import math, sys, torch, torchvision
+import torch, torchvision
 
 from torch import nn
-from torch.nn import functional as F
+import torch.nn.functional as F
+
+######################################################################
+
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+    torch.backends.cudnn.benchmark = True
+else:
+    device = torch.device('cpu')
 
 ######################################################################
 
@@ -28,27 +37,33 @@ parser.add_argument('--mnist_classes',
 
 ######################################################################
 
+def entropy(target):
+    probas = []
+    for k in range(target.max() + 1):
+        n = (target == k).sum().item()
+        if n > 0: probas.append(n)
+    probas = torch.tensor(probas).float()
+    probas /= probas.sum()
+    return - (probas * probas.log()).sum().item()
+
+######################################################################
+
 args = parser.parse_args()
 
 if args.seed >= 0:
     torch.manual_seed(args.seed)
 
-used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'))
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
 
 ######################################################################
 
 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
-train_input  = train_set.train_data.view(-1, 1, 28, 28).float()
-train_target = train_set.train_labels
+train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
+train_target = train_set.train_labels.to(device)
 
 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
-test_input = test_set.test_data.view(-1, 1, 28, 28).float()
-test_target = test_set.test_labels
-
-if torch.cuda.is_available():
-    used_MNIST_classes = used_MNIST_classes.cuda()
-    train_input, train_target = train_input.cuda(), train_target.cuda()
-    test_input, test_target = test_input.cuda(), test_target.cuda()
+test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
+test_target = test_set.test_labels.to(device)
 
 mu, std = train_input.mean(), train_input.std()
 train_input.sub_(mu).div_(std)
@@ -58,10 +73,10 @@ test_input.sub_(mu).div_(std)
 
 # Returns a triplet of tensors (a, b, c), where a and b contain each
 # half of the samples, with a[i] and b[i] of same class for any i, and
-# c is a 1d long tensor with the count of pairs per class used.
+# c is a 1d long tensor real classes
 
 def create_image_pairs(train = False):
-    ua, ub = [], []
+    ua, ub, uc = [], [], []
 
     if train:
         input, target = train_input, train_target
@@ -76,18 +91,25 @@ def create_image_pairs(train = False):
         hs = x.size(0)//2
         ua.append(x.narrow(0, 0, hs))
         ub.append(x.narrow(0, hs, hs))
+        uc.append(target[used_indices])
 
     a = torch.cat(ua, 0)
     b = torch.cat(ub, 0)
+    c = torch.cat(uc, 0)
     perm = torch.randperm(a.size(0))
     a = a[perm].contiguous()
     b = b[perm].contiguous()
-    c = torch.tensor([x.size(0) for x in ua])
 
     return a, b, c
 
 ######################################################################
 
+# Returns a triplet a, b, c where a are the standard MNIST images, c
+# the classes, and b is a Nx2 tensor, eith for every n:
+#
+#   b[n, 0] ~ Uniform(0, 10)
+#   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
+
 def create_image_values_pairs(train = False):
     ua, ub = [], []
 
@@ -105,21 +127,66 @@ def create_image_values_pairs(train = False):
     target = target[used_indices].contiguous()
 
     a = input
+    c = target
 
     b = a.new(a.size(0), 2)
     b[:, 0].uniform_(10)
     b[:, 1].uniform_(0.5)
     b[:, 1] += b[:, 0] + target.float()
 
-    c = torch.tensor([(target == k).sum().item() for k in used_MNIST_classes])
-
     return a, b, c
 
 ######################################################################
 
-class NetImagePair(nn.Module):
+def create_sequences_pairs(train = False):
+    nb, length = 10000, 1024
+    noise_level = 1e-2
+
+    nb_classes = 4
+    ha = torch.randint(nb_classes, (nb, ), device = device) + 1
+    # hb = torch.randint(nb_classes, (nb, ), device = device)
+    hb = ha
+
+    pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
+    a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    a = a - pos.view(nb, 1)
+    a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
+    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + nb_classes)
+    noise = a.new(a.size()).normal_(0, noise_level)
+    a = a + noise
+
+    pos = torch.empty(nb, device = device).uniform_(0.5)
+    b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    b1 = b1 - pos.view(nb, 1)
+    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1)
+    pos = pos + hb.float() / (nb_classes + 1) * 0.5
+    b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    b2 = b2 - pos.view(nb, 1)
+    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1)
+
+    b = b1 + b2
+    noise = b.new(b.size()).normal_(0, noise_level)
+    b = b + noise
+
+    ######################################################################
+    # for k in range(10):
+        # file = open(f'/tmp/dat{k:02d}', 'w')
+        # for i in range(a.size(1)):
+            # file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
+        # file.close()
+    # exit(0)
+    ######################################################################
+
+    a = (a - a.mean()) / a.std()
+    b = (b - b.mean()) / b.std()
+
+    return a, b, ha
+
+######################################################################
+
+class NetForImagePair(nn.Module):
     def __init__(self):
-        super(NetImagePair, self).__init__()
+        super(NetForImagePair, self).__init__()
         self.features_a = nn.Sequential(
             nn.Conv2d(1, 16, kernel_size = 5),
             nn.MaxPool2d(3), nn.ReLU(),
@@ -148,9 +215,9 @@ class NetImagePair(nn.Module):
 
 ######################################################################
 
-class NetImageValuesPair(nn.Module):
+class NetForImageValuesPair(nn.Module):
     def __init__(self):
-        super(NetImageValuesPair, self).__init__()
+        super(NetForImageValuesPair, self).__init__()
         self.features_a = nn.Sequential(
             nn.Conv2d(1, 16, kernel_size = 5),
             nn.MaxPool2d(3), nn.ReLU(),
@@ -178,12 +245,60 @@ class NetImageValuesPair(nn.Module):
 
 ######################################################################
 
+class NetForSequencePair(nn.Module):
+
+    def feature_model(self):
+        return  nn.Sequential(
+            nn.Conv1d(1, self.nc, kernel_size = 5),
+            nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
+            nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
+            nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
+            nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
+            nn.MaxPool1d(2), nn.ReLU(),
+        )
+
+    def __init__(self):
+        super(NetForSequencePair, self).__init__()
+
+        self.nc = 32
+        self.nh = 256
+
+        self.features_a = self.feature_model()
+        self.features_b = self.feature_model()
+
+        self.fully_connected = nn.Sequential(
+            nn.Linear(2 * self.nc, self.nh),
+            nn.ReLU(),
+            nn.Linear(self.nh, 1)
+        )
+
+    def forward(self, a, b):
+        a = a.view(a.size(0), 1, a.size(1))
+        a = self.features_a(a)
+        a = F.avg_pool1d(a, a.size(2))
+
+        b = b.view(b.size(0), 1, b.size(1))
+        b = self.features_b(b)
+        b = F.avg_pool1d(b, b.size(2))
+
+        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
+        return self.fully_connected(x)
+
+######################################################################
+
 if args.data == 'image_pair':
     create_pairs = create_image_pairs
-    model = NetImagePair()
+    model = NetForImagePair()
 elif args.data == 'image_values_pair':
     create_pairs = create_image_values_pairs
-    model = NetImageValuesPair()
+    model = NetForImageValuesPair()
+elif args.data == 'sequence_pair':
+    create_pairs = create_sequences_pairs
+    model = NetForSequencePair()
 else:
     raise Exception('Unknown data ' + args.data)
 
@@ -195,17 +310,11 @@ print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
 
 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
 
-if torch.cuda.is_available():
-    model.cuda()
+model.to(device)
 
 for e in range(nb_epochs):
 
-    input_a, input_b, count = create_pairs(train = True)
-
-    # The information bound is the entropy of the class distribution
-    class_proba = count.float()
-    class_proba /= class_proba.sum()
-    class_entropy = - (class_proba.log() * class_proba).sum().item()
+    input_a, input_b, classes = create_pairs(train = True)
 
     input_br = input_b[torch.randperm(input_b.size(0))]
 
@@ -223,19 +332,15 @@ for e in range(nb_epochs):
 
     acc_mi /= (input_a.size(0) // batch_size)
 
-    print('%d %.04f %.04f' % (e, acc_mi / math.log(2), class_entropy / math.log(2)))
+    print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
 
     sys.stdout.flush()
 
 ######################################################################
 
-input_a, input_b, count = create_pairs(train = False)
+input_a, input_b, classes = create_pairs(train = False)
 
 for e in range(nb_epochs):
-    class_proba = count.float()
-    class_proba /= class_proba.sum()
-    class_entropy = - (class_proba.log() * class_proba).sum().item()
-
     input_br = input_b[torch.randperm(input_b.size(0))]
 
     acc_mi = 0.0
@@ -248,6 +353,6 @@ for e in range(nb_epochs):
 
     acc_mi /= (input_a.size(0) // batch_size)
 
-print('test %.04f %.04f'%(acc_mi / math.log(2), class_entropy / math.log(2)))
+print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
 
 ######################################################################