X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=mine_mnist.py;h=7845d813ae20a86ed14f7fb5dd29970841c9397f;hp=c6dc287c5b4b99ab36284da6ca712d5cdeda8ad7;hb=4817708db50a18242ade3ba88971dd4ef0a73004;hpb=663ddb29ecd584102f5a19eefc686b7d5ed77d3e diff --git a/mine_mnist.py b/mine_mnist.py index c6dc287..7845d81 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -1,137 +1,261 @@ #!/usr/bin/env python -# @XREMOTE_HOST: elk.fleuret.org -# @XREMOTE_EXEC: ~/conda/bin/python -# @XREMOTE_PRE: ln -s ~/data/pytorch ./data -# @XREMOTE_PRE: killall -q -9 python || true +import argparse, math, sys -import math, sys, torch, torchvision +import torch, torchvision from torch import nn -from torch.nn import functional as F ###################################################################### -# Returns a pair of tensors (a, b, c), where a and b are Nx1x28x28 -# tensors containing images, with a[i] and b[i] of same class for any -# i, and c is a 1d long tensor with the count of pairs per class used. +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + +###################################################################### + +parser = argparse.ArgumentParser( + description = 'An implementation of Mutual Information estimator with a deep model', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--data', + type = str, default = 'image_pair', + help = 'What data') + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + +parser.add_argument('--mnist_classes', + type = str, default = '0, 1, 3, 5, 6, 7, 8, 9', + help = 'What MNIST classes to use') + +###################################################################### + +def entropy(target): + probas = [] + for k in range(target.max() + 1): + n = (target == k).sum().item() + if n > 0: probas.append(n) + probas = torch.tensor(probas).float() + probas /= probas.sum() + return - (probas * probas.log()).sum().item() + +###################################################################### + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) + +used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device) + +###################################################################### + +train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) +train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float() +train_target = train_set.train_labels.to(device) + +test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) +test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float() +test_target = test_set.test_labels.to(device) + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +###################################################################### -def create_pair_set(used_classes, input, target): - u = [] +# Returns a triplet of tensors (a, b, c), where a and b contain each +# half of the samples, with a[i] and b[i] of same class for any i, and +# c is a 1d long tensor real classes - for i in used_classes: +def create_image_pairs(train = False): + ua, ub = [], [] + + if train: + input, target = train_input, train_target + else: + input, target = test_input, test_target + + for i in used_MNIST_classes: used_indices = torch.arange(input.size(0), device = target.device)\ .masked_select(target == i.item()) x = input[used_indices] x = x[torch.randperm(x.size(0))] - # Careful with odd numbers of samples in a class - x = x[0:2 * (x.size(0) // 2)].reshape(-1, 2, 28, 28) - u.append(x) + hs = x.size(0)//2 + ua.append(x.narrow(0, 0, hs)) + ub.append(x.narrow(0, hs, hs)) + uc.append(target[used_indices]) - x = torch.cat(u, 0) - x = x[torch.randperm(x.size(0))] - c = torch.tensor([x.size(0) for x in u]) + a = torch.cat(ua, 0) + b = torch.cat(ub, 0) + c = torch.cat(uc, 0) + perm = torch.randperm(a.size(0)) + a = a[perm].contiguous() + b = b[perm].contiguous() - return x.narrow(1, 0, 1).contiguous(), x.narrow(1, 1, 1).contiguous(), c + return a, b, c ###################################################################### -class Net(nn.Module): +# Returns a triplet a, b, c where a are the standard MNIST images, c +# the classes, and b is a Nx2 tensor, eith for every n: +# +# b[n, 0] ~ Uniform(0, 10) +# b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n] + +def create_image_values_pairs(train = False): + ua, ub = [], [] + + if train: + input, target = train_input, train_target + else: + input, target = test_input, test_target + + m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device) + m[used_MNIST_classes] = 1 + m = m[target] + used_indices = torch.arange(input.size(0), device = target.device).masked_select(m) + + input = input[used_indices].contiguous() + target = target[used_indices].contiguous() + + a = input + c = target + + b = a.new(a.size(0), 2) + b[:, 0].uniform_(10) + b[:, 1].uniform_(0.5) + b[:, 1] += b[:, 0] + target.float() + + return a, b, c + +###################################################################### + +class NetForImagePair(nn.Module): def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(2, 32, kernel_size = 5) - self.conv2 = nn.Conv2d(32, 64, kernel_size = 5) - self.fc1 = nn.Linear(256, 200) - self.fc2 = nn.Linear(200, 1) + super(NetForImagePair, self).__init__() + self.features_a = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.features_b = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.fully_connected = nn.Sequential( + nn.Linear(256, 200), + nn.ReLU(), + nn.Linear(200, 1) + ) def forward(self, a, b): + a = self.features_a(a).view(a.size(0), -1) + b = self.features_b(b).view(b.size(0), -1) x = torch.cat((a, b), 1) - x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3)) - x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2)) - x = x.view(x.size(0), -1) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x + return self.fully_connected(x) ###################################################################### -train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) -train_input = train_set.train_data.view(-1, 1, 28, 28).float() -train_target = train_set.train_labels - -test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) -test_input = test_set.test_data.view(-1, 1, 28, 28).float() -test_target = test_set.test_labels +class NetForImageValuesPair(nn.Module): + def __init__(self): + super(NetForImageValuesPair, self).__init__() + self.features_a = nn.Sequential( + nn.Conv2d(1, 16, kernel_size = 5), + nn.MaxPool2d(3), nn.ReLU(), + nn.Conv2d(16, 32, kernel_size = 5), + nn.MaxPool2d(2), nn.ReLU(), + ) + + self.features_b = nn.Sequential( + nn.Linear(2, 32), nn.ReLU(), + nn.Linear(32, 32), nn.ReLU(), + nn.Linear(32, 128), nn.ReLU(), + ) + + self.fully_connected = nn.Sequential( + nn.Linear(256, 200), + nn.ReLU(), + nn.Linear(200, 1) + ) -mu, std = train_input.mean(), train_input.std() -train_input.sub_(mu).div_(std) -test_input.sub_(mu).div_(std) + def forward(self, a, b): + a = self.features_a(a).view(a.size(0), -1) + b = self.features_b(b).view(b.size(0), -1) + x = torch.cat((a, b), 1) + return self.fully_connected(x) ###################################################################### -# The information bound is the log of the number of classes in there +if args.data == 'image_pair': + create_pairs = create_image_pairs + model = NetForImagePair() +elif args.data == 'image_values_pair': + create_pairs = create_image_values_pairs + model = NetForImageValuesPair() +else: + raise Exception('Unknown data ' + args.data) -# used_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9]) -used_classes = torch.tensor([ 3, 4, 7, 0 ]) +###################################################################### nb_epochs, batch_size = 50, 100 -model = Net() +print('nb_parameters %d' % sum(x.numel() for x in model.parameters())) + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) -if torch.cuda.is_available(): - model.cuda() - train_input, train_target = train_input.cuda(), train_target.cuda() - test_input, test_target = test_input.cuda(), test_target.cuda() +model.to(device) for e in range(nb_epochs): - input_a, input_b, count = create_pair_set(used_classes, train_input, train_target) - - class_proba = count.float() - class_proba /= class_proba.sum() - class_entropy = - (class_proba.log() * class_proba).sum().item() + input_a, input_b, classes = create_pairs(train = True) input_br = input_b[torch.randperm(input_b.size(0))] - mi = 0.0 + acc_mi = 0.0 for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), input_b.split(batch_size), input_br.split(batch_size)): - loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()) - mi -= loss.item() + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + loss = - mi + acc_mi += mi.item() optimizer.zero_grad() loss.backward() optimizer.step() - mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // batch_size) - print('%d %.04f %.04f'%(e, mi / math.log(2), class_entropy / math.log(2))) + print('%d %.04f %.04f' % (e, acc_mi / math.log(2), entropy(classes) / math.log(2))) sys.stdout.flush() ###################################################################### -input_a, input_b, count = create_pair_set(used_classes, test_input, test_target) +input_a, input_b, classes = create_pairs(train = False) for e in range(nb_epochs): - class_proba = count.float() - class_proba /= class_proba.sum() - class_entropy = - (class_proba.log() * class_proba).sum().item() - input_br = input_b[torch.randperm(input_b.size(0))] - mi = 0.0 + acc_mi = 0.0 for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), input_b.split(batch_size), input_br.split(batch_size)): - loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()) - mi -= loss.item() + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() - mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // batch_size) -print('test %.04f %.04f'%(mi / math.log(2), class_entropy / math.log(2))) +print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2))) ######################################################################