X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mine_mnist.py;h=5ab427fd7ec684f439cfa1c377e727968f2c3040;hb=1ef32d0b85a179a36a1cd7899b8301bcb8e563d2;hp=82f6530592f71c1d8b734e823725b208b87b4908;hpb=7abf09dfdb0059f0a0f4d4fcb5892f030ee75e4e;p=pytorch.git diff --git a/mine_mnist.py b/mine_mnist.py index 82f6530..5ab427f 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -1,10 +1,5 @@ #!/usr/bin/env python -# @XREMOTE_HOST: elk.fleuret.org -# @XREMOTE_EXEC: ~/conda/bin/python -# @XREMOTE_PRE: ln -s ~/data/pytorch ./data -# @XREMOTE_PRE: killall -q -9 python || true - import math, sys, torch, torchvision from torch import nn @@ -12,27 +7,53 @@ from torch.nn import functional as F ###################################################################### -# Returns a pair of tensors (a, b, c), where a and b are Nx1x28x28 -# tensors containing images, with a[i] and b[i] of same class for any -# i, and c is a 1d long tensor with the count of pairs per class used. +train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) +train_input = train_set.train_data.view(-1, 1, 28, 28).float() +train_target = train_set.train_labels -def create_pair_set(used_classes, input, target): - u = [] +test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) +test_input = test_set.test_data.view(-1, 1, 28, 28).float() +test_target = test_set.test_labels - for i in used_classes: +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +used_MNIST_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9]) +# used_MNIST_classes = torch.tensor([ 0, 9, 7 ]) +# used_MNIST_classes = torch.tensor([ 3, 4, 7, 0 ]) + +###################################################################### + +# Returns a triplet of tensors (a, b, c), where a and b contain each +# half of the samples, with a[i] and b[i] of same class for any i, and +# c is a 1d long tensor with the count of pairs per class used. + +def create_MNIST_pair_set(train = False): + ua, ub = [], [] + + if train: + input, target = train_input, train_target + else: + input, target = test_input, test_target + + for i in used_MNIST_classes: used_indices = torch.arange(input.size(0), device = target.device)\ .masked_select(target == i.item()) x = input[used_indices] x = x[torch.randperm(x.size(0))] - # Careful with odd numbers of samples in a class - x = x[0:2 * (x.size(0) // 2)].reshape(-1, 2, 28, 28) - u.append(x) + hs = x.size(0)//2 + ua.append(x.narrow(0, 0, hs)) + ub.append(x.narrow(0, hs, hs)) - x = torch.cat(u, 0) - x = x[torch.randperm(x.size(0))] - c = torch.tensor([x.size(0) for x in u]) + a = torch.cat(ua, 0) + b = torch.cat(ub, 0) + perm = torch.randperm(a.size(0)) + a = a[perm].contiguous() + b = b[perm].contiguous() + c = torch.tensor([x.size(0) for x in ua]) - return x.narrow(1, 0, 1).contiguous(), x.narrow(1, 1, 1).contiguous(), c + return a, b, c ###################################################################### @@ -45,6 +66,7 @@ class Net(nn.Module): self.fc2 = nn.Linear(200, 1) def forward(self, a, b): + # Make the two images a single two-channel image x = torch.cat((a, b), 1) x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3)) x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2)) @@ -55,28 +77,12 @@ class Net(nn.Module): ###################################################################### -train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) -train_input = train_set.train_data.view(-1, 1, 28, 28).float() -train_target = train_set.train_labels - -test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) -test_input = test_set.test_data.view(-1, 1, 28, 28).float() -test_target = test_set.test_labels - -mu, std = train_input.mean(), train_input.std() -train_input.sub_(mu).div_(std) -test_input.sub_(mu).div_(std) - -###################################################################### - -# The information bound is the log of the number of classes in there - -# used_classes = torch.tensor([ 0, 1, 3, 5, 6, 7, 8, 9]) -used_classes = torch.tensor([ 3, 4, 7, 0 ]) - nb_epochs, batch_size = 50, 100 model = Net() + +print('nb_parameters %d' % sum(x.numel() for x in model.parameters())) + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) if torch.cuda.is_available(): @@ -86,8 +92,9 @@ if torch.cuda.is_available(): for e in range(nb_epochs): - input_a, input_b, count = create_pair_set(used_classes, train_input, train_target) + input_a, input_b, count = create_MNIST_pair_set(train = True) + # The information bound is the entropy of the class distribution class_proba = count.float() class_proba /= class_proba.sum() class_entropy = - (class_proba.log() * class_proba).sum().item() @@ -108,13 +115,13 @@ for e in range(nb_epochs): acc_mi /= (input_a.size(0) // batch_size) - print('%d %.04f %.04f'%(e, acc_mi / math.log(2), class_entropy / math.log(2))) + print('%d %.04f %.04f' % (e, acc_mi / math.log(2), class_entropy / math.log(2))) sys.stdout.flush() ###################################################################### -input_a, input_b, count = create_pair_set(used_classes, test_input, test_target) +input_a, input_b, count = create_MNIST_pair_set(train = False) for e in range(nb_epochs): class_proba = count.float() @@ -128,8 +135,8 @@ for e in range(nb_epochs): for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), input_b.split(batch_size), input_br.split(batch_size)): - loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()) - acc_mi -= loss.item() + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() acc_mi /= (input_a.size(0) // batch_size)