Update.
[pytorch.git] / mine_mnist.py
index c22d7fe..f8b859d 100755 (executable)
@@ -1,5 +1,22 @@
 #!/usr/bin/env python
 
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
+#########################################################################
+
 import argparse, math, sys
 from copy import deepcopy
 
@@ -11,21 +28,36 @@ import torch.nn.functional as F
 ######################################################################
 
 if torch.cuda.is_available():
-    device = torch.device('cuda')
     torch.backends.cudnn.benchmark = True
+    device = torch.device('cuda')
 else:
     device = torch.device('cpu')
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'An implementation of Mutual Information estimator with a deep model',
+    description = '''An implementation of a Mutual Information estimator with a deep model
+
+Three different toy data-sets are implemented:
+
+ (1) Two MNIST images of same class. The "true" MI is the log of the
+     number of used MNIST classes.
+
+ (2) One MNIST image and a pair of real numbers whose difference is
+     the class of the image. The "true" MI is the log of the number of
+     used MNIST classes.
+
+ (3) Two 1d sequences, the first with a single peak, the second with
+     two peaks, and the height of the peak in the first is the
+     difference of timing of the peaks in the second. The "true" MI is
+     the log of the number of possible peak heights.''',
+
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
 parser.add_argument('--data',
                     type = str, default = 'image_pair',
-                    help = 'What data')
+                    help = 'What data: image_pair, image_values_pair, sequence_pair')
 
 parser.add_argument('--seed',
                     type = int, default = 0,
@@ -35,16 +67,24 @@ parser.add_argument('--mnist_classes',
                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
                     help = 'What MNIST classes to use')
 
-######################################################################
+parser.add_argument('--nb_classes',
+                    type = int, default = 2,
+                    help = 'How many classes for sequences')
 
-def entropy(target):
-    probas = []
-    for k in range(target.max() + 1):
-        n = (target == k).sum().item()
-        if n > 0: probas.append(n)
-    probas = torch.tensor(probas).float()
-    probas /= probas.sum()
-    return - (probas * probas.log()).sum().item()
+parser.add_argument('--nb_epochs',
+                    type = int, default = 50,
+                    help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100,
+                    help = 'Batch size')
+
+parser.add_argument('--learning_rate',
+                    type = float, default = 1e-3,
+                    help = 'Batch size')
+
+parser.add_argument('--independent', action = 'store_true',
+                    help = 'Should the pair components be independent')
 
 ######################################################################
 
@@ -57,6 +97,17 @@ used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device =
 
 ######################################################################
 
+def entropy(target):
+    probas = []
+    for k in range(target.max() + 1):
+        n = (target == k).sum().item()
+        if n > 0: probas.append(n)
+    probas = torch.tensor(probas).float()
+    probas /= probas.sum()
+    return - (probas * probas.log()).sum().item()
+
+######################################################################
+
 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
 train_target = train_set.train_labels.to(device)
@@ -98,6 +149,9 @@ def create_image_pairs(train = False):
     c = torch.cat(uc, 0)
     perm = torch.randperm(a.size(0))
     a = a[perm].contiguous()
+
+    if args.independent:
+        perm = torch.randperm(a.size(0))
     b = b[perm].contiguous()
 
     return a, b, c
@@ -105,7 +159,7 @@ def create_image_pairs(train = False):
 ######################################################################
 
 # Returns a triplet a, b, c where a are the standard MNIST images, c
-# the classes, and b is a Nx2 tensor, eith for every n:
+# the classes, and b is a Nx2 tensor, with for every n:
 #
 #   b[n, 0] ~ Uniform(0, 10)
 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
@@ -130,9 +184,14 @@ def create_image_values_pairs(train = False):
     c = target
 
     b = a.new(a.size(0), 2)
-    b[:, 0].uniform_(10)
-    b[:, 1].uniform_(0.5)
-    b[:, 1] += b[:, 0] + target.float()
+    b[:, 0].uniform_(0.0, 10.0)
+    b[:, 1].uniform_(0.0, 0.5)
+
+    if args.independent:
+        b[:, 1] += b[:, 0] + \
+                   used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+    else:
+        b[:, 1] += b[:, 0] + target.float()
 
     return a, b, c
 
@@ -140,45 +199,38 @@ def create_image_values_pairs(train = False):
 
 def create_sequences_pairs(train = False):
     nb, length = 10000, 1024
-    noise_level = 1e-2
+    noise_level = 2e-2
 
-    nb_classes = 4
-    ha = torch.randint(nb_classes, (nb, ), device = device) + 1
-    # hb = torch.randint(nb_classes, (nb, ), device = device)
-    hb = ha
+    ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
+    if args.independent:
+        hb = torch.randint(args.nb_classes, (nb, ), device = device)
+    else:
+        hb = ha
 
     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
     a = a - pos.view(nb, 1)
     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
-    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + nb_classes)
+    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
     noise = a.new(a.size()).normal_(0, noise_level)
     a = a + noise
 
-    pos = torch.empty(nb, device = device).uniform_(0.5)
+    pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
     b1 = b1 - pos.view(nb, 1)
-    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1)
-    pos = pos + hb.float() / (nb_classes + 1) * 0.5
+    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
+    pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
+    # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
     b2 = b2 - pos.view(nb, 1)
-    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1)
+    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
 
     b = b1 + b2
     noise = b.new(b.size()).normal_(0, noise_level)
     b = b + noise
 
-    ######################################################################
-    # for k in range(10):
-        # file = open(f'/tmp/dat{k:02d}', 'w')
-        # for i in range(a.size(1)):
-            # file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
-        # file.close()
-    # exit(0)
-    ######################################################################
-
-    a = (a - a.mean()) / a.std()
-    b = (b - b.mean()) / b.std()
+    # a = (a - a.mean()) / a.std()
+    # b = (b - b.mean()) / b.std()
 
     return a, b, ha
 
@@ -248,17 +300,21 @@ class NetForImageValuesPair(nn.Module):
 class NetForSequencePair(nn.Module):
 
     def feature_model(self):
+        kernel_size = 11
+        pooling_size = 4
         return  nn.Sequential(
-            nn.Conv1d(1, self.nc, kernel_size = 5),
-            nn.MaxPool1d(2), nn.ReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
-            nn.MaxPool1d(2), nn.ReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
-            nn.MaxPool1d(2), nn.ReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
-            nn.MaxPool1d(2), nn.ReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = 5),
-            nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),
+            nn.LeakyReLU(),
         )
 
     def __init__(self):
@@ -293,66 +349,78 @@ class NetForSequencePair(nn.Module):
 if args.data == 'image_pair':
     create_pairs = create_image_pairs
     model = NetForImagePair()
+
 elif args.data == 'image_values_pair':
     create_pairs = create_image_values_pairs
     model = NetForImageValuesPair()
+
 elif args.data == 'sequence_pair':
     create_pairs = create_sequences_pairs
     model = NetForSequencePair()
+
+    ######################
+    ## Save for figures
+    a, b, c = create_pairs()
+    for k in range(10):
+        file = open(f'train_{k:02d}.dat', 'w')
+        for i in range(a.size(1)):
+            file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
+        file.close()
+    ######################
+
 else:
     raise Exception('Unknown data ' + args.data)
 
 ######################################################################
+# Train
 
-nb_epochs, batch_size = 50, 100
-
-print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
-
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
 
 model.to(device)
 
-for e in range(nb_epochs):
+input_a, input_b, classes = create_pairs(train = True)
+
+for e in range(args.nb_epochs):
 
-    input_a, input_b, classes = create_pairs(train = True)
+    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
 
     input_br = input_b[torch.randperm(input_b.size(0))]
 
     acc_mi = 0.0
 
-    for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
-                                          input_b.split(batch_size),
-                                          input_br.split(batch_size)):
+    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
+                                          input_b.split(args.batch_size),
+                                          input_br.split(args.batch_size)):
         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
-        loss = - mi
         acc_mi += mi.item()
+        loss = - mi
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-    acc_mi /= (input_a.size(0) // batch_size)
+    acc_mi /= (input_a.size(0) // args.batch_size)
 
-    print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
+    print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
 
     sys.stdout.flush()
 
 ######################################################################
+# Test
 
 input_a, input_b, classes = create_pairs(train = False)
 
-for e in range(nb_epochs):
-    input_br = input_b[torch.randperm(input_b.size(0))]
+input_br = input_b[torch.randperm(input_b.size(0))]
 
-    acc_mi = 0.0
+acc_mi = 0.0
 
-    for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
-                                          input_b.split(batch_size),
-                                          input_br.split(batch_size)):
-        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
-        acc_mi += mi.item()
+for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
+                                      input_b.split(args.batch_size),
+                                      input_br.split(args.batch_size)):
+    mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+    acc_mi += mi.item()
 
-    acc_mi /= (input_a.size(0) // batch_size)
+acc_mi /= (input_a.size(0) // args.batch_size)
 
-print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
+print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
 
 ######################################################################