X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=mine_mnist.py;h=389544b780f7b5d051460d914a25285f4251cc5d;hp=7845d813ae20a86ed14f7fb5dd29970841c9397f;hb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0;hpb=4817708db50a18242ade3ba88971dd4ef0a73004 diff --git a/mine_mnist.py b/mine_mnist.py index 7845d81..389544b 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -1,14 +1,17 @@ #!/usr/bin/env python import argparse, math, sys +from copy import deepcopy import torch, torchvision from torch import nn +import torch.nn.functional as F ###################################################################### if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True device = torch.device('cuda') else: device = torch.device('cpu') @@ -32,6 +35,21 @@ parser.add_argument('--mnist_classes', type = str, default = '0, 1, 3, 5, 6, 7, 8, 9', help = 'What MNIST classes to use') +parser.add_argument('--nb_classes', + type = int, default = 2, + help = 'How many classes for sequences') + +parser.add_argument('--nb_epochs', + type = int, default = 50, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 100, + help = 'Batch size') + +parser.add_argument('--independent', action = 'store_true', + help = 'Should the pair components be independent') + ###################################################################### def entropy(target): @@ -43,6 +61,12 @@ def entropy(target): probas /= probas.sum() return - (probas * probas.log()).sum().item() +def robust_log_mean_exp(x): + # a = x.max() + # return (x-a).exp().mean().log() + a + # a = x.max() + return x.exp().mean().log() + ###################################################################### args = parser.parse_args() @@ -73,7 +97,7 @@ test_input.sub_(mu).div_(std) # c is a 1d long tensor real classes def create_image_pairs(train = False): - ua, ub = [], [] + ua, ub, uc = [], [], [] if train: input, target = train_input, train_target @@ -95,6 +119,9 @@ def create_image_pairs(train = False): c = torch.cat(uc, 0) perm = torch.randperm(a.size(0)) a = a[perm].contiguous() + + if args.independent: + perm = torch.randperm(a.size(0)) b = b[perm].contiguous() return a, b, c @@ -127,14 +154,56 @@ def create_image_values_pairs(train = False): c = target b = a.new(a.size(0), 2) - b[:, 0].uniform_(10) - b[:, 1].uniform_(0.5) - b[:, 1] += b[:, 0] + target.float() + b[:, 0].uniform_(0.0, 10.0) + b[:, 1].uniform_(0.0, 0.5) + + if args.independent: + b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())] + else: + b[:, 1] += b[:, 0] + target.float() return a, b, c ###################################################################### +def create_sequences_pairs(train = False): + nb, length = 10000, 1024 + noise_level = 2e-2 + + ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1 + if args.independent: + hb = torch.randint(args.nb_classes, (nb, ), device = device) + else: + hb = ha + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) + a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + a = a - pos.view(nb, 1) + a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1) + a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes) + noise = a.new(a.size()).normal_(0, noise_level) + a = a + noise + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.5) + b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b1 = b1 - pos.view(nb, 1) + b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25 + pos = pos + hb.float() / (args.nb_classes + 1) * 0.5 + b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b2 = b2 - pos.view(nb, 1) + b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25 + + b = b1 + b2 + noise = b.new(b.size()).normal_(0, noise_level) + b = b + noise + + # a = (a - a.mean()) / a.std() + # b = (b - b.mean()) / b.std() + + return a, b, ha + +###################################################################### + class NetForImagePair(nn.Module): def __init__(self): super(NetForImagePair, self).__init__() @@ -196,26 +265,83 @@ class NetForImageValuesPair(nn.Module): ###################################################################### +class NetForSequencePair(nn.Module): + + def feature_model(self): + kernel_size = 11 + pooling_size = 4 + return nn.Sequential( + nn.Conv1d( 1, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + ) + + def __init__(self): + super(NetForSequencePair, self).__init__() + + self.nc = 32 + self.nh = 256 + + self.features_a = self.feature_model() + self.features_b = self.feature_model() + + self.fully_connected = nn.Sequential( + nn.Linear(2 * self.nc, self.nh), + nn.ReLU(), + nn.Linear(self.nh, 1) + ) + + def forward(self, a, b): + a = a.view(a.size(0), 1, a.size(1)) + a = self.features_a(a) + a = F.avg_pool1d(a, a.size(2)) + + b = b.view(b.size(0), 1, b.size(1)) + b = self.features_b(b) + b = F.avg_pool1d(b, b.size(2)) + + x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1) + return self.fully_connected(x) + +###################################################################### + if args.data == 'image_pair': create_pairs = create_image_pairs model = NetForImagePair() elif args.data == 'image_values_pair': create_pairs = create_image_values_pairs model = NetForImageValuesPair() +elif args.data == 'sequence_pair': + create_pairs = create_sequences_pairs + model = NetForSequencePair() + ###################################################################### + a, b, c = create_pairs() + for k in range(10): + file = open(f'train_{k:02d}.dat', 'w') + for i in range(a.size(1)): + file.write(f'{a[k, i]:f} {b[k,i]:f}\n') + file.close() + # exit(0) + ###################################################################### else: raise Exception('Unknown data ' + args.data) ###################################################################### -nb_epochs, batch_size = 50, 100 - print('nb_parameters %d' % sum(x.numel() for x in model.parameters())) -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) - model.to(device) -for e in range(nb_epochs): +for e in range(args.nb_epochs): input_a, input_b, classes = create_pairs(train = True) @@ -223,19 +349,21 @@ for e in range(nb_epochs): acc_mi = 0.0 - for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), - input_b.split(batch_size), - input_br.split(batch_size)): + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) + + for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() - loss = - mi acc_mi += mi.item() + loss = - mi optimizer.zero_grad() loss.backward() optimizer.step() - acc_mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // args.batch_size) - print('%d %.04f %.04f' % (e, acc_mi / math.log(2), entropy(classes) / math.log(2))) + print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2))) sys.stdout.flush() @@ -243,18 +371,17 @@ for e in range(nb_epochs): input_a, input_b, classes = create_pairs(train = False) -for e in range(nb_epochs): - input_br = input_b[torch.randperm(input_b.size(0))] +input_br = input_b[torch.randperm(input_b.size(0))] - acc_mi = 0.0 +acc_mi = 0.0 - for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), - input_b.split(batch_size), - input_br.split(batch_size)): - mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() - acc_mi += mi.item() +for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() - acc_mi /= (input_a.size(0) // batch_size) +acc_mi /= (input_a.size(0) // args.batch_size) print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))