Update
[pytorch] / mi_estimator.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
20 import argparse, math, sys
21 from copy import deepcopy
22
23 import torch, torchvision
24
25 from torch import nn
26 import torch.nn.functional as F
27
28 ######################################################################
29
30 if torch.cuda.is_available():
31     torch.backends.cudnn.benchmark = True
32     device = torch.device('cuda')
33 else:
34     device = torch.device('cpu')
35
36 ######################################################################
37
38 parser = argparse.ArgumentParser(
39     description = '''An implementation of a Mutual Information estimator with a deep model
40
41 Three different toy data-sets are implemented:
42
43  (1) Two MNIST images of same class. The "true" MI is the log of the
44      number of used MNIST classes.
45
46  (2) One MNIST image and a pair of real numbers whose difference is
47      the class of the image. The "true" MI is the log of the number of
48      used MNIST classes.
49
50  (3) Two 1d sequences, the first with a single peak, the second with
51      two peaks, and the height of the peak in the first is the
52      difference of timing of the peaks in the second. The "true" MI is
53      the log of the number of possible peak heights.''',
54
55     formatter_class = argparse.ArgumentDefaultsHelpFormatter
56 )
57
58 parser.add_argument('--data',
59                     type = str, default = 'image_pair',
60                     help = 'What data: image_pair, image_values_pair, sequence_pair')
61
62 parser.add_argument('--seed',
63                     type = int, default = 0,
64                     help = 'Random seed (default 0, < 0 is no seeding)')
65
66 parser.add_argument('--mnist_classes',
67                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
68                     help = 'What MNIST classes to use')
69
70 parser.add_argument('--nb_classes',
71                     type = int, default = 2,
72                     help = 'How many classes for sequences')
73
74 parser.add_argument('--nb_epochs',
75                     type = int, default = 50,
76                     help = 'How many epochs')
77
78 parser.add_argument('--batch_size',
79                     type = int, default = 100,
80                     help = 'Batch size')
81
82 parser.add_argument('--learning_rate',
83                     type = float, default = 1e-3,
84                     help = 'Batch size')
85
86 parser.add_argument('--independent', action = 'store_true',
87                     help = 'Should the pair components be independent')
88
89 ######################################################################
90
91 args = parser.parse_args()
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
97
98 ######################################################################
99
100 def entropy(target):
101     probas = []
102     for k in range(target.max() + 1):
103         n = (target == k).sum().item()
104         if n > 0: probas.append(n)
105     probas = torch.tensor(probas).float()
106     probas /= probas.sum()
107     return - (probas * probas.log()).sum().item()
108
109 ######################################################################
110
111 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
112 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
113 train_target = train_set.train_labels.to(device)
114
115 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
116 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
117 test_target = test_set.test_labels.to(device)
118
119 mu, std = train_input.mean(), train_input.std()
120 train_input.sub_(mu).div_(std)
121 test_input.sub_(mu).div_(std)
122
123 ######################################################################
124
125 # Returns a triplet of tensors (a, b, c), where a and b contain each
126 # half of the samples, with a[i] and b[i] of same class for any i, and
127 # c is a 1d long tensor real classes
128
129 def create_image_pairs(train = False):
130     ua, ub, uc = [], [], []
131
132     if train:
133         input, target = train_input, train_target
134     else:
135         input, target = test_input, test_target
136
137     for i in used_MNIST_classes:
138         used_indices = torch.arange(input.size(0), device = target.device)\
139                             .masked_select(target == i.item())
140         x = input[used_indices]
141         x = x[torch.randperm(x.size(0))]
142         hs = x.size(0)//2
143         ua.append(x.narrow(0, 0, hs))
144         ub.append(x.narrow(0, hs, hs))
145         uc.append(target[used_indices])
146
147     a = torch.cat(ua, 0)
148     b = torch.cat(ub, 0)
149     c = torch.cat(uc, 0)
150     perm = torch.randperm(a.size(0))
151     a = a[perm].contiguous()
152
153     if args.independent:
154         perm = torch.randperm(a.size(0))
155     b = b[perm].contiguous()
156
157     return a, b, c
158
159 ######################################################################
160
161 # Returns a triplet a, b, c where a are the standard MNIST images, c
162 # the classes, and b is a Nx2 tensor, with for every n:
163 #
164 #   b[n, 0] ~ Uniform(0, 10)
165 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
166
167 def create_image_values_pairs(train = False):
168     ua, ub = [], []
169
170     if train:
171         input, target = train_input, train_target
172     else:
173         input, target = test_input, test_target
174
175     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
176     m[used_MNIST_classes] = 1
177     m = m[target]
178     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
179
180     input = input[used_indices].contiguous()
181     target = target[used_indices].contiguous()
182
183     a = input
184     c = target
185
186     b = a.new(a.size(0), 2)
187     b[:, 0].uniform_(0.0, 10.0)
188     b[:, 1].uniform_(0.0, 0.5)
189
190     if args.independent:
191         b[:, 1] += b[:, 0] + \
192                    used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
193     else:
194         b[:, 1] += b[:, 0] + target.float()
195
196     return a, b, c
197
198 ######################################################################
199
200 def create_sequences_pairs(train = False):
201     nb, length = 10000, 1024
202     noise_level = 2e-2
203
204     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
205     if args.independent:
206         hb = torch.randint(args.nb_classes, (nb, ), device = device)
207     else:
208         hb = ha
209
210     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
211     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
212     a = a - pos.view(nb, 1)
213     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
214     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
215     noise = a.new(a.size()).normal_(0, noise_level)
216     a = a + noise
217
218     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
219     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
220     b1 = b1 - pos.view(nb, 1)
221     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
222     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
223     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
224     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
225     b2 = b2 - pos.view(nb, 1)
226     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
227
228     b = b1 + b2
229     noise = b.new(b.size()).normal_(0, noise_level)
230     b = b + noise
231
232     # a = (a - a.mean()) / a.std()
233     # b = (b - b.mean()) / b.std()
234
235     return a, b, ha
236
237 ######################################################################
238
239 class NetForImagePair(nn.Module):
240     def __init__(self):
241         super(NetForImagePair, self).__init__()
242         self.features_a = nn.Sequential(
243             nn.Conv2d(1, 16, kernel_size = 5),
244             nn.MaxPool2d(3), nn.ReLU(),
245             nn.Conv2d(16, 32, kernel_size = 5),
246             nn.MaxPool2d(2), nn.ReLU(),
247         )
248
249         self.features_b = nn.Sequential(
250             nn.Conv2d(1, 16, kernel_size = 5),
251             nn.MaxPool2d(3), nn.ReLU(),
252             nn.Conv2d(16, 32, kernel_size = 5),
253             nn.MaxPool2d(2), nn.ReLU(),
254         )
255
256         self.fully_connected = nn.Sequential(
257             nn.Linear(256, 200),
258             nn.ReLU(),
259             nn.Linear(200, 1)
260         )
261
262     def forward(self, a, b):
263         a = self.features_a(a).view(a.size(0), -1)
264         b = self.features_b(b).view(b.size(0), -1)
265         x = torch.cat((a, b), 1)
266         return self.fully_connected(x)
267
268 ######################################################################
269
270 class NetForImageValuesPair(nn.Module):
271     def __init__(self):
272         super(NetForImageValuesPair, self).__init__()
273         self.features_a = nn.Sequential(
274             nn.Conv2d(1, 16, kernel_size = 5),
275             nn.MaxPool2d(3), nn.ReLU(),
276             nn.Conv2d(16, 32, kernel_size = 5),
277             nn.MaxPool2d(2), nn.ReLU(),
278         )
279
280         self.features_b = nn.Sequential(
281             nn.Linear(2, 32), nn.ReLU(),
282             nn.Linear(32, 32), nn.ReLU(),
283             nn.Linear(32, 128), nn.ReLU(),
284         )
285
286         self.fully_connected = nn.Sequential(
287             nn.Linear(256, 200),
288             nn.ReLU(),
289             nn.Linear(200, 1)
290         )
291
292     def forward(self, a, b):
293         a = self.features_a(a).view(a.size(0), -1)
294         b = self.features_b(b).view(b.size(0), -1)
295         x = torch.cat((a, b), 1)
296         return self.fully_connected(x)
297
298 ######################################################################
299
300 class NetForSequencePair(nn.Module):
301
302     def feature_model(self):
303         kernel_size = 11
304         pooling_size = 4
305         return  nn.Sequential(
306             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
307             nn.AvgPool1d(pooling_size),
308             nn.LeakyReLU(),
309             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
310             nn.AvgPool1d(pooling_size),
311             nn.LeakyReLU(),
312             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
313             nn.AvgPool1d(pooling_size),
314             nn.LeakyReLU(),
315             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
316             nn.AvgPool1d(pooling_size),
317             nn.LeakyReLU(),
318         )
319
320     def __init__(self):
321         super(NetForSequencePair, self).__init__()
322
323         self.nc = 32
324         self.nh = 256
325
326         self.features_a = self.feature_model()
327         self.features_b = self.feature_model()
328
329         self.fully_connected = nn.Sequential(
330             nn.Linear(2 * self.nc, self.nh),
331             nn.ReLU(),
332             nn.Linear(self.nh, 1)
333         )
334
335     def forward(self, a, b):
336         a = a.view(a.size(0), 1, a.size(1))
337         a = self.features_a(a)
338         a = F.avg_pool1d(a, a.size(2))
339
340         b = b.view(b.size(0), 1, b.size(1))
341         b = self.features_b(b)
342         b = F.avg_pool1d(b, b.size(2))
343
344         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
345         return self.fully_connected(x)
346
347 ######################################################################
348
349 if args.data == 'image_pair':
350     create_pairs = create_image_pairs
351     model = NetForImagePair()
352
353 elif args.data == 'image_values_pair':
354     create_pairs = create_image_values_pairs
355     model = NetForImageValuesPair()
356
357 elif args.data == 'sequence_pair':
358     create_pairs = create_sequences_pairs
359     model = NetForSequencePair()
360
361     ######################
362     ## Save for figures
363     a, b, c = create_pairs()
364     for k in range(10):
365         file = open(f'train_{k:02d}.dat', 'w')
366         for i in range(a.size(1)):
367             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
368         file.close()
369     ######################
370
371 else:
372     raise Exception('Unknown data ' + args.data)
373
374 ######################################################################
375 # Train
376
377 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
378
379 model.to(device)
380
381 input_a, input_b, classes = create_pairs(train = True)
382
383 for e in range(args.nb_epochs):
384
385     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
386
387     input_br = input_b[torch.randperm(input_b.size(0))]
388
389     acc_mi = 0.0
390
391     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
392                                           input_b.split(args.batch_size),
393                                           input_br.split(args.batch_size)):
394         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
395         acc_mi += mi.item()
396         loss = - mi
397         optimizer.zero_grad()
398         loss.backward()
399         optimizer.step()
400
401     acc_mi /= (input_a.size(0) // args.batch_size)
402
403     print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
404
405     sys.stdout.flush()
406
407 ######################################################################
408 # Test
409
410 input_a, input_b, classes = create_pairs(train = False)
411
412 input_br = input_b[torch.randperm(input_b.size(0))]
413
414 acc_mi = 0.0
415
416 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
417                                       input_b.split(args.batch_size),
418                                       input_br.split(args.batch_size)):
419     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
420     acc_mi += mi.item()
421
422 acc_mi /= (input_a.size(0) // args.batch_size)
423
424 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
425
426 ######################################################################