X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mi_estimator.py;fp=mi_estimator.py;h=f8b859dd6fc07dc69f2b8fecfa2c30f4d1eeace3;hb=236238fdfe7d65612b58fbbb5bb29cff4ec45d54;hp=0000000000000000000000000000000000000000;hpb=f07dc15e422fd58a38c5a2ea3b260cd2b44e21af;p=pytorch.git diff --git a/mi_estimator.py b/mi_estimator.py new file mode 100755 index 0000000..f8b859d --- /dev/null +++ b/mi_estimator.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python + +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### + +import argparse, math, sys +from copy import deepcopy + +import torch, torchvision + +from torch import nn +import torch.nn.functional as F + +###################################################################### + +if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + device = torch.device('cuda') +else: + device = torch.device('cpu') + +###################################################################### + +parser = argparse.ArgumentParser( + description = '''An implementation of a Mutual Information estimator with a deep model + +Three different toy data-sets are implemented: + + (1) Two MNIST images of same class. The "true" MI is the log of the + number of used MNIST classes. + + (2) One MNIST image and a pair of real numbers whose difference is + the class of the image. The "true" MI is the log of the number of + used MNIST classes. + + (3) Two 1d sequences, the first with a single peak, the second with + two peaks, and the height of the peak in the first is the + difference of timing of the peaks in the second. The "true" MI is + the log of the number of possible peak heights.''', + + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--data', + type = str, default = 'image_pair', + help = 'What data: image_pair, image_values_pair, sequence_pair') + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + +parser.add_argument('--mnist_classes', + type = str, default = '0, 1, 3, 5, 6, 7, 8, 9', + help = 'What MNIST classes to use') + +parser.add_argument('--nb_classes', + type = int, default = 2, + help = 'How many classes for sequences') + +parser.add_argument('--nb_epochs', + type = int, default = 50, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 100, + help = 'Batch size') + +parser.add_argument('--learning_rate', + type = float, default = 1e-3, + help = 'Batch size') + +parser.add_argument('--independent', action = 'store_true', + help = 'Should the pair components be independent') + +###################################################################### + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) + +used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device) + +###################################################################### + +def entropy(target): + probas = [] + for k in range(target.max() + 1): + n = (target == k).sum().item() + if n > 0: probas.append(n) + probas = torch.tensor(probas).float() + probas /= probas.sum() + return - (probas * probas.log()).sum().item() + +###################################################################### + +train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) +train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float() +train_target = train_set.train_labels.to(device) + +test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) +test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float() +test_target = test_set.test_labels.to(device) + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +###################################################################### + +# Returns a triplet of tensors (a, b, c), where a and b contain each +# half of the samples, with a[i] and b[i] of same class for any i, and +# c is a 1d long tensor real classes + +def create_image_pairs(train = False): + ua, ub, uc = [], [], [] + + if train: + input, target = train_input, train_target + else: + input, target = test_input, test_target + + for i in used_MNIST_classes: + used_indices = torch.arange(input.size(0), device = target.device)\ + .masked_select(target == i.item()) + x = input[used_indices] + x = x[torch.randperm(x.size(0))] + hs = x.size(0)//2 + ua.append(x.narrow(0, 0, hs)) + ub.append(x.narrow(0, hs, hs)) + uc.append(target[used_indices]) + + a = torch.cat(ua, 0) + b = torch.cat(ub, 0) + c = torch.cat(uc, 0) + perm = torch.randperm(a.size(0)) + a = a[perm].contiguous() + + if args.independent: + perm = torch.randperm(a.size(0)) + b = b[perm].contiguous() + + return a, b, c + +###################################################################### + +# Returns a triplet a, b, c where a are the standard MNIST images, c +# the classes, and b is a Nx2 tensor, with for every n: +# +# b[n, 0] ~ Uniform(0, 10) +# b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n] + +def create_image_values_pairs(train = False): + ua, ub = [], [] + + if train: + input, target = train_input, train_target + else: + input, target = test_input, test_target + + m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device) + m[used_MNIST_classes] = 1 + m = m[target] + used_indices = torch.arange(input.size(0), device = target.device).masked_select(m) + + input = input[used_indices].contiguous() + target = target[used_indices].contiguous() + + a = input + c = target + + b = a.new(a.size(0), 2) + b[:, 0].uniform_(0.0, 10.0) + b[:, 1].uniform_(0.0, 0.5) + + if args.independent: + b[:, 1] += b[:, 0] + \ + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())] + else: + b[:, 1] += b[:, 0] + target.float() + + return a, b, c + +###################################################################### + +def create_sequences_pairs(train = False): + nb, length = 10000, 1024 + noise_level = 2e-2 + + ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1 + if args.independent: + hb = torch.randint(args.nb_classes, (nb, ), device = device) + else: + hb = ha + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) + a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + a = a - pos.view(nb, 1) + a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1) + a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes) + noise = a.new(a.size()).normal_(0, noise_level) + a = a + noise + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.5) + b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b1 = b1 - pos.view(nb, 1) + b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25 + pos = pos + hb.float() / (args.nb_classes + 1) * 0.5 + # pos += pos.new(hb.size()).uniform_(0.0, 0.01) + b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b2 = b2 - pos.view(nb, 1) + b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25 + + b = b1 + b2 + noise = b.new(b.size()).normal_(0, noise_level) + b = b + noise + + # a = (a - a.mean()) / a.std() + # b = (b - b.mean()) / b.std() + + return a, b, ha + +###################################################################### + +class NetForImagePair(nn.Module): + def __init__(self): + super(NetForImagePair, self).__init__() + self.features_a = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.features_b = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.fully_connected = nn.Sequential( + nn.Linear(256, 200), + nn.ReLU(), + nn.Linear(200, 1) + ) + + def forward(self, a, b): + a = self.features_a(a).view(a.size(0), -1) + b = self.features_b(b).view(b.size(0), -1) + x = torch.cat((a, b), 1) + return self.fully_connected(x) + +###################################################################### + +class NetForImageValuesPair(nn.Module): + def __init__(self): + super(NetForImageValuesPair, self).__init__() + self.features_a = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.features_b = nn.Sequential( + nn.Linear(2, 32), nn.ReLU(), + nn.Linear(32, 32), nn.ReLU(), + nn.Linear(32, 128), nn.ReLU(), + ) + + self.fully_connected = nn.Sequential( + nn.Linear(256, 200), + nn.ReLU(), + nn.Linear(200, 1) + ) + + def forward(self, a, b): + a = self.features_a(a).view(a.size(0), -1) + b = self.features_b(b).view(b.size(0), -1) + x = torch.cat((a, b), 1) + return self.fully_connected(x) + +###################################################################### + +class NetForSequencePair(nn.Module): + + def feature_model(self): + kernel_size = 11 + pooling_size = 4 + return nn.Sequential( + nn.Conv1d( 1, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + ) + + def __init__(self): + super(NetForSequencePair, self).__init__() + + self.nc = 32 + self.nh = 256 + + self.features_a = self.feature_model() + self.features_b = self.feature_model() + + self.fully_connected = nn.Sequential( + nn.Linear(2 * self.nc, self.nh), + nn.ReLU(), + nn.Linear(self.nh, 1) + ) + + def forward(self, a, b): + a = a.view(a.size(0), 1, a.size(1)) + a = self.features_a(a) + a = F.avg_pool1d(a, a.size(2)) + + b = b.view(b.size(0), 1, b.size(1)) + b = self.features_b(b) + b = F.avg_pool1d(b, b.size(2)) + + x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1) + return self.fully_connected(x) + +###################################################################### + +if args.data == 'image_pair': + create_pairs = create_image_pairs + model = NetForImagePair() + +elif args.data == 'image_values_pair': + create_pairs = create_image_values_pairs + model = NetForImageValuesPair() + +elif args.data == 'sequence_pair': + create_pairs = create_sequences_pairs + model = NetForSequencePair() + + ###################### + ## Save for figures + a, b, c = create_pairs() + for k in range(10): + file = open(f'train_{k:02d}.dat', 'w') + for i in range(a.size(1)): + file.write(f'{a[k, i]:f} {b[k,i]:f}\n') + file.close() + ###################### + +else: + raise Exception('Unknown data ' + args.data) + +###################################################################### +# Train + +print(f'nb_parameters {sum(x.numel() for x in model.parameters())}') + +model.to(device) + +input_a, input_b, classes = create_pairs(train = True) + +for e in range(args.nb_epochs): + + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) + + input_br = input_b[torch.randperm(input_b.size(0))] + + acc_mi = 0.0 + + for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() + loss = - mi + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_mi /= (input_a.size(0) // args.batch_size) + + print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}') + + sys.stdout.flush() + +###################################################################### +# Test + +input_a, input_b, classes = create_pairs(train = False) + +input_br = input_b[torch.randperm(input_b.size(0))] + +acc_mi = 0.0 + +for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() + +acc_mi /= (input_a.size(0) // args.batch_size) + +print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}') + +######################################################################