From 0e7d17bae1211d8019428ba4cd59e0af2a7ab074 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 14 Dec 2018 10:07:10 +0100 Subject: [PATCH] Update. --- mine_mnist.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 2 deletions(-) diff --git a/mine_mnist.py b/mine_mnist.py index 7845d81..c22d7fe 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -1,15 +1,18 @@ #!/usr/bin/env python import argparse, math, sys +from copy import deepcopy import torch, torchvision from torch import nn +import torch.nn.functional as F ###################################################################### if torch.cuda.is_available(): device = torch.device('cuda') + torch.backends.cudnn.benchmark = True else: device = torch.device('cpu') @@ -73,7 +76,7 @@ test_input.sub_(mu).div_(std) # c is a 1d long tensor real classes def create_image_pairs(train = False): - ua, ub = [], [] + ua, ub, uc = [], [], [] if train: input, target = train_input, train_target @@ -135,6 +138,52 @@ def create_image_values_pairs(train = False): ###################################################################### +def create_sequences_pairs(train = False): + nb, length = 10000, 1024 + noise_level = 1e-2 + + nb_classes = 4 + ha = torch.randint(nb_classes, (nb, ), device = device) + 1 + # hb = torch.randint(nb_classes, (nb, ), device = device) + hb = ha + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) + a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + a = a - pos.view(nb, 1) + a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1) + a = a * ha.float().view(-1, 1).expand_as(a) / (1 + nb_classes) + noise = a.new(a.size()).normal_(0, noise_level) + a = a + noise + + pos = torch.empty(nb, device = device).uniform_(0.5) + b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b1 = b1 - pos.view(nb, 1) + b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) + pos = pos + hb.float() / (nb_classes + 1) * 0.5 + b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b2 = b2 - pos.view(nb, 1) + b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) + + b = b1 + b2 + noise = b.new(b.size()).normal_(0, noise_level) + b = b + noise + + ###################################################################### + # for k in range(10): + # file = open(f'/tmp/dat{k:02d}', 'w') + # for i in range(a.size(1)): + # file.write(f'{a[k, i]:f} {b[k,i]:f}\n') + # file.close() + # exit(0) + ###################################################################### + + a = (a - a.mean()) / a.std() + b = (b - b.mean()) / b.std() + + return a, b, ha + +###################################################################### + class NetForImagePair(nn.Module): def __init__(self): super(NetForImagePair, self).__init__() @@ -196,12 +245,60 @@ class NetForImageValuesPair(nn.Module): ###################################################################### +class NetForSequencePair(nn.Module): + + def feature_model(self): + return nn.Sequential( + nn.Conv1d(1, self.nc, kernel_size = 5), + nn.MaxPool1d(2), nn.ReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = 5), + nn.MaxPool1d(2), nn.ReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = 5), + nn.MaxPool1d(2), nn.ReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = 5), + nn.MaxPool1d(2), nn.ReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = 5), + nn.MaxPool1d(2), nn.ReLU(), + ) + + def __init__(self): + super(NetForSequencePair, self).__init__() + + self.nc = 32 + self.nh = 256 + + self.features_a = self.feature_model() + self.features_b = self.feature_model() + + self.fully_connected = nn.Sequential( + nn.Linear(2 * self.nc, self.nh), + nn.ReLU(), + nn.Linear(self.nh, 1) + ) + + def forward(self, a, b): + a = a.view(a.size(0), 1, a.size(1)) + a = self.features_a(a) + a = F.avg_pool1d(a, a.size(2)) + + b = b.view(b.size(0), 1, b.size(1)) + b = self.features_b(b) + b = F.avg_pool1d(b, b.size(2)) + + x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1) + return self.fully_connected(x) + +###################################################################### + if args.data == 'image_pair': create_pairs = create_image_pairs model = NetForImagePair() elif args.data == 'image_values_pair': create_pairs = create_image_values_pairs model = NetForImageValuesPair() +elif args.data == 'sequence_pair': + create_pairs = create_sequences_pairs + model = NetForSequencePair() else: raise Exception('Unknown data ' + args.data) @@ -235,7 +332,7 @@ for e in range(nb_epochs): acc_mi /= (input_a.size(0) // batch_size) - print('%d %.04f %.04f' % (e, acc_mi / math.log(2), entropy(classes) / math.log(2))) + print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2))) sys.stdout.flush() -- 2.39.5