3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation. #
8 # This program is distributed in the hope that it will be useful, but #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
11 # General Public License for more details. #
13 # You should have received a copy of the GNU General Public License #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
16 # Written by Francois Fleuret, (C) Idiap Research Institute #
18 # Contact <francois.fleuret@idiap.ch> for comments & bug reports #
19 #########################################################################
import argparse, math, sys
from copy import deepcopy

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
29 ######################################################################
# Select the compute device once, up front; everything below allocates on it.
if torch.cuda.is_available():
    # Benchmark mode lets cudnn auto-tune kernels; fine here since all
    # batches have fixed shapes.
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
37 ######################################################################
parser = argparse.ArgumentParser(
    description = '''An implementation of a Mutual Information estimator with a deep model

Three different toy data-sets are implemented, each consists of
pairs of samples, that may be from different spaces:

(1) Two MNIST images of same class. The "true" MI is the log of the
number of used MNIST classes.

(2) One MNIST image and a pair of real numbers whose difference is
the class of the image. The "true" MI is the log of the number of
used MNIST classes.

(3) Two 1d sequences, the first with a single peak, the second with
two peaks, and the height of the peak in the first is the
difference of timing of the peaks in the second. The "true" MI is
the log of the number of possible peak heights.''',

    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--data',
                    type = str, default = 'image_pair',
                    help = 'What data: image_pair, image_values_pair, sequence_pair')

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

parser.add_argument('--mnist_classes',
                    type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
                    help = 'What MNIST classes to use')

parser.add_argument('--nb_classes',
                    type = int, default = 2,
                    help = 'How many classes for sequences')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Batch size')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3,
                    help = 'Learning rate')

parser.add_argument('--independent', action = 'store_true',
                    help = 'Should the pair components be independent')
91 ######################################################################
args = parser.parse_args()

# Per the --seed help string, a negative seed means "do not seed".
if args.seed >= 0:
    torch.manual_seed(args.seed)

# Parse the comma-separated class list explicitly instead of eval()-ing a
# raw command-line string (eval on user-supplied input is unsafe).
used_MNIST_classes = torch.tensor(
    [int(c) for c in args.mnist_classes.split(',')],
    device = device
)
100 ######################################################################
def entropy(target):
    """Empirical entropy (in nats) of the class distribution in target.

    target is a 1d tensor of non-negative integer class labels; classes
    that do not occur contribute nothing.
    """
    # bincount does the per-class counting in one pass instead of one
    # boolean reduction per class.
    counts = torch.bincount(target.view(-1)).float()
    probas = counts[counts > 0]
    probas /= probas.sum()
    return - (probas * probas.log()).sum().item()
111 ######################################################################
# Load MNIST once at module level; tensors live on `device`.
train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
# .data / .targets replace the long-deprecated .train_data / .train_labels
train_input = train_set.data.view(-1, 1, 28, 28).to(device).float()
train_target = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).to(device).float()
test_target = test_set.targets.to(device)

# Standardize with the train-set statistics; the same mu/std are applied
# to the test set on purpose.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
125 ######################################################################
# Returns a triplet of tensors (a, b, c), where a and b contain each
# half of the samples, with a[i] and b[i] of same class for any i, and
# c is a 1d long tensor of the real classes

def create_image_pairs(train = False):
    """Build pairs of same-class MNIST images.

    Returns (a, b, c) where a[i] and b[i] are two distinct images of the
    same class; with --independent, b is shuffled independently of a.
    """
    ua, ub, uc = [], [], []

    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    for i in used_MNIST_classes:
        # Indices of the samples of class i.
        used_indices = torch.arange(input.size(0), device = target.device)\
                            .masked_select(target == i.item())
        x = input[used_indices]
        x = x[torch.randperm(x.size(0))]
        # Split the shuffled class samples in two halves, one per pair member.
        hs = x.size(0) // 2
        ua.append(x.narrow(0, 0, hs))
        ub.append(x.narrow(0, hs, hs))
        # NOTE(review): keeps every index of the class, so c is longer than
        # a/b; only its class distribution is used (entropy) — TODO confirm.
        uc.append(target[used_indices])

    a = torch.cat(ua, 0)
    b = torch.cat(ub, 0)
    c = torch.cat(uc, 0)

    perm = torch.randperm(a.size(0))
    a = a[perm].contiguous()

    if args.independent:
        # Re-draw the permutation so b's classes are independent of a's.
        perm = torch.randperm(a.size(0))
    b = b[perm].contiguous()

    return a, b, c
161 ######################################################################
# Returns a triplet a, b, c where a are the standard MNIST images, c
# the classes, and b is a Nx2 tensor, with for every n:
#
#   b[n, 0] ~ Uniform(0, 10)
#   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]

def create_image_values_pairs(train = False):
    """Pair each MNIST image with two reals whose difference encodes its class."""
    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    # Per-class boolean flags for the classes in use (bool dtype: uint8
    # masks are deprecated in modern torch).
    m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.bool, device = target.device)
    m[used_MNIST_classes] = True
    # Expand to a per-sample mask before selecting the indices.
    m = m[target]
    used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)

    input = input[used_indices].contiguous()
    target = target[used_indices].contiguous()

    a = input

    b = a.new(a.size(0), 2)
    b[:, 0].uniform_(0.0, 10.0)
    b[:, 1].uniform_(0.0, 0.5)

    if args.independent:
        # Add a class drawn independently of the image's actual class.
        b[:, 1] += b[:, 0] + \
                   used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
    else:
        b[:, 1] += b[:, 0] + target.float()

    return a, b, target
200 ######################################################################
def create_sequences_pairs(train = False):
    """Build pairs of 1d sequences.

    Returns (a, b, ha): a has one decaying peak whose height encodes ha,
    b has two peaks whose time offset encodes hb; hb == ha unless
    --independent is set.
    """
    nb, length = 10000, 1024
    # assumed noise amplitude — original value lost in this chunk, TODO confirm
    noise_level = 2e-2

    ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
    if args.independent:
        hb = torch.randint(args.nb_classes, (nb, ), device = device)
    else:
        hb = ha

    # First sequence: single half-life-0.1 decaying peak at a random
    # position, scaled by ha.
    pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
    a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    a = a - pos.view(nb, 1)
    a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
    noise = a.new(a.size()).normal_(0, noise_level)
    a = a + noise

    # Second sequence: two peaks, the distance between them driven by hb.
    pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
    b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    b1 = b1 - pos.view(nb, 1)
    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
    pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
    # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
    b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    b2 = b2 - pos.view(nb, 1)
    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25

    b = b1 + b2
    noise = b.new(b.size()).normal_(0, noise_level)
    b = b + noise

    return a, b, ha
238 ######################################################################
class NetForImagePair(nn.Module):
    """Scores a pair of MNIST images; two conv towers feed one MLP head."""

    def __init__(self):
        super(NetForImagePair, self).__init__()
        # Each tower maps 1x28x28 -> 32x2x2 = 128 features.
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        self.features_b = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        # 2 * 128 concatenated features -> scalar score; hidden width 200
        # reconstructed (original lines lost in this chunk) — TODO confirm.
        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)
269 ######################################################################
class NetForImageValuesPair(nn.Module):
    """Scores an (MNIST image, 2d real vector) pair."""

    def __init__(self):
        super(NetForImageValuesPair, self).__init__()
        # Image tower: 1x28x28 -> 32x2x2 = 128 features.
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        # Value tower: 2 -> 128 features.
        self.features_b = nn.Sequential(
            nn.Linear(2, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 128), nn.ReLU(),
        )

        # 2 * 128 concatenated features -> scalar score; hidden width 200
        # reconstructed (original lines lost in this chunk) — TODO confirm.
        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)
299 ######################################################################
class NetForSequencePair(nn.Module):
    """Scores a pair of 1d sequences with two shared-architecture conv towers."""

    def feature_model(self):
        # kernel/pooling sizes and activations were lost in this chunk;
        # these values keep the temporal dimension positive for the
        # length-1024 sequences used here — TODO confirm against original.
        kernel_size = 5
        pooling_size = 4
        return nn.Sequential(
            nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
        )

    def __init__(self):
        super(NetForSequencePair, self).__init__()

        # Channel / hidden widths reconstructed — TODO confirm.
        self.nc = 32
        self.nh = 256

        self.features_a = self.feature_model()
        self.features_b = self.feature_model()

        self.fully_connected = nn.Sequential(
            nn.Linear(2 * self.nc, self.nh),
            nn.ReLU(),
            nn.Linear(self.nh, 1)
        )

    def forward(self, a, b):
        # Sequences arrive as (batch, length); conv1d wants (batch, 1, length).
        a = a.view(a.size(0), 1, a.size(1))
        a = self.features_a(a)
        # Global average pooling over whatever temporal extent remains.
        a = F.avg_pool1d(a, a.size(2))

        b = b.view(b.size(0), 1, b.size(1))
        b = self.features_b(b)
        b = F.avg_pool1d(b, b.size(2))

        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
        return self.fully_connected(x)
348 ######################################################################
# Pick the data generator and matching model from --data.
if args.data == 'image_pair':
    create_pairs = create_image_pairs
    model = NetForImagePair()

elif args.data == 'image_values_pair':
    create_pairs = create_image_values_pairs
    model = NetForImageValuesPair()

elif args.data == 'sequence_pair':
    create_pairs = create_sequences_pairs
    model = NetForSequencePair()

    ######################
    # Dump a few example sequence pairs to .dat files for plotting
    # (count of 10 reconstructed — original line lost, TODO confirm).
    a, b, c = create_pairs()
    for k in range(10):
        with open(f'train_{k:02d}.dat', 'w') as dat_file:
            for i in range(a.size(1)):
                dat_file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
    ######################

else:
    raise Exception('Unknown data ' + args.data)
375 ######################################################################
print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')

# Parameters must live on the same device as the data created below.
model.to(device)

input_a, input_b, classes = create_pairs(train = True)

for e in range(args.nb_epochs):

    # NOTE(review): the optimizer is re-created every epoch, which resets
    # Adam's moment estimates each time; kept as-is to preserve the
    # original training schedule.
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)

    # Shuffled copy of b: samples from the product of the marginals.
    input_br = input_b[torch.randperm(input_b.size(0))]

    acc_mi = 0.0

    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        # Donsker-Varadhan lower bound on I(A;B), maximized by gradient ascent.
        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
        acc_mi += mi.item()
        loss = - mi

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc_mi /= (input_a.size(0) // args.batch_size)

    # Report estimated MI and reference entropy, both in bits.
    print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')

    sys.stdout.flush()
408 ######################################################################
# Final evaluation of the MI estimate on held-out data.
input_a, input_b, classes = create_pairs(train = False)

input_br = input_b[torch.randperm(input_b.size(0))]

acc_mi = 0.0

# No parameter updates here, so skip gradient tracking.
with torch.no_grad():
    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
        acc_mi += mi.item()

acc_mi /= (input_a.size(0) // args.batch_size)

print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
427 ######################################################################