Update.
[pytorch.git] / mi_estimator.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import argparse, math, sys
9 from copy import deepcopy
10
11 import torch, torchvision
12
13 from torch import nn
14 import torch.nn.functional as F
15
16 ######################################################################
17
18 if torch.cuda.is_available():
19     torch.backends.cudnn.benchmark = True
20     device = torch.device("cuda")
21 else:
22     device = torch.device("cpu")
23
24 ######################################################################
25
26 parser = argparse.ArgumentParser(
27     description="""An implementation of a Mutual Information estimator with a deep model
28
29     Three different toy data-sets are implemented, each consists of
30     pairs of samples, that may be from different spaces:
31
32     (1) Two MNIST images of same class. The "true" MI is the log of the
33     number of used MNIST classes.
34
35     (2) One MNIST image and a pair of real numbers whose difference is
36     the class of the image. The "true" MI is the log of the number of
37     used MNIST classes.
38
39     (3) Two 1d sequences, the first with a single peak, the second with
40     two peaks, and the height of the peak in the first is the
41     difference of timing of the peaks in the second. The "true" MI is
42     the log of the number of possible peak heights.""",
43     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
44 )
45
46 parser.add_argument(
47     "--data",
48     type=str,
49     default="image_pair",
50     help="What data: image_pair, image_values_pair, sequence_pair",
51 )
52
53 parser.add_argument(
54     "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
55 )
56
57 parser.add_argument(
58     "--mnist_classes",
59     type=str,
60     default="0, 1, 3, 5, 6, 7, 8, 9",
61     help="What MNIST classes to use",
62 )
63
64 parser.add_argument(
65     "--nb_classes", type=int, default=2, help="How many classes for sequences"
66 )
67
68 parser.add_argument("--nb_epochs", type=int, default=50, help="How many epochs")
69
70 parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
71
72 parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
73
74 parser.add_argument(
75     "--independent",
76     action="store_true",
77     help="Should the pair components be independent",
78 )
79
80 ######################################################################
81
82 args = parser.parse_args()
83
84 if args.seed >= 0:
85     torch.manual_seed(args.seed)
86
87 used_MNIST_classes = torch.tensor(eval("[" + args.mnist_classes + "]"), device=device)
88
89 ######################################################################
90
91
92 def entropy(target):
93     probas = []
94     for k in range(target.max() + 1):
95         n = (target == k).sum().item()
96         if n > 0:
97             probas.append(n)
98     probas = torch.tensor(probas).float()
99     probas /= probas.sum()
100     return -(probas * probas.log()).sum().item()
101
102
103 ######################################################################
104
105 train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
106 train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
107 train_target = train_set.train_labels.to(device)
108
109 test_set = torchvision.datasets.MNIST("./data/mnist/", train=False, download=True)
110 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
111 test_target = test_set.test_labels.to(device)
112
113 mu, std = train_input.mean(), train_input.std()
114 train_input.sub_(mu).div_(std)
115 test_input.sub_(mu).div_(std)
116
117 ######################################################################
118
119 # Returns a triplet of tensors (a, b, c), where a and b contain each
120 # half of the samples, with a[i] and b[i] of same class for any i, and
121 # c is a 1d long tensor real classes
122
123
124 def create_image_pairs(train=False):
125     ua, ub, uc = [], [], []
126
127     if train:
128         input, target = train_input, train_target
129     else:
130         input, target = test_input, test_target
131
132     for i in used_MNIST_classes:
133         used_indices = torch.arange(input.size(0), device=target.device).masked_select(
134             target == i.item()
135         )
136         x = input[used_indices]
137         x = x[torch.randperm(x.size(0))]
138         hs = x.size(0) // 2
139         ua.append(x.narrow(0, 0, hs))
140         ub.append(x.narrow(0, hs, hs))
141         uc.append(target[used_indices])
142
143     a = torch.cat(ua, 0)
144     b = torch.cat(ub, 0)
145     c = torch.cat(uc, 0)
146     perm = torch.randperm(a.size(0))
147     a = a[perm].contiguous()
148
149     if args.independent:
150         perm = torch.randperm(a.size(0))
151     b = b[perm].contiguous()
152
153     return a, b, c
154
155
156 ######################################################################
157
158 # Returns a triplet a, b, c where a are the standard MNIST images, c
159 # the classes, and b is a Nx2 tensor, with for every n:
160 #
161 #   b[n, 0] ~ Uniform(0, 10)
162 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
163
164
165 def create_image_values_pairs(train=False):
166     ua, ub = [], []
167
168     if train:
169         input, target = train_input, train_target
170     else:
171         input, target = test_input, test_target
172
173     m = torch.zeros(
174         used_MNIST_classes.max() + 1, dtype=torch.uint8, device=target.device
175     )
176     m[used_MNIST_classes] = 1
177     m = m[target]
178     used_indices = torch.arange(input.size(0), device=target.device).masked_select(m)
179
180     input = input[used_indices].contiguous()
181     target = target[used_indices].contiguous()
182
183     a = input
184     c = target
185
186     b = a.new(a.size(0), 2)
187     b[:, 0].uniform_(0.0, 10.0)
188     b[:, 1].uniform_(0.0, 0.5)
189
190     if args.independent:
191         b[:, 1] += (
192             b[:, 0]
193             + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
194         )
195     else:
196         b[:, 1] += b[:, 0] + target.float()
197
198     return a, b, c
199
200
201 ######################################################################
202
203 #
204
205
206 def create_sequences_pairs(train=False):
207     nb, length = 10000, 1024
208     noise_level = 2e-2
209
210     ha = torch.randint(args.nb_classes, (nb,), device=device) + 1
211     if args.independent:
212         hb = torch.randint(args.nb_classes, (nb,), device=device)
213     else:
214         hb = ha
215
216     pos = torch.empty(nb, device=device).uniform_(0.0, 0.9)
217     a = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
218     a = a - pos.view(nb, 1)
219     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
220     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
221     noise = a.new(a.size()).normal_(0, noise_level)
222     a = a + noise
223
224     pos = torch.empty(nb, device=device).uniform_(0.0, 0.5)
225     b1 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
226     b1 = b1 - pos.view(nb, 1)
227     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
228     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
229     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
230     b2 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
231     b2 = b2 - pos.view(nb, 1)
232     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
233
234     b = b1 + b2
235     noise = b.new(b.size()).normal_(0, noise_level)
236     b = b + noise
237
238     return a, b, ha
239
240
241 ######################################################################
242
243
244 class NetForImagePair(nn.Module):
245     def __init__(self):
246         super().__init__()
247         self.features_a = nn.Sequential(
248             nn.Conv2d(1, 16, kernel_size=5),
249             nn.MaxPool2d(3),
250             nn.ReLU(),
251             nn.Conv2d(16, 32, kernel_size=5),
252             nn.MaxPool2d(2),
253             nn.ReLU(),
254         )
255
256         self.features_b = nn.Sequential(
257             nn.Conv2d(1, 16, kernel_size=5),
258             nn.MaxPool2d(3),
259             nn.ReLU(),
260             nn.Conv2d(16, 32, kernel_size=5),
261             nn.MaxPool2d(2),
262             nn.ReLU(),
263         )
264
265         self.fully_connected = nn.Sequential(
266             nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
267         )
268
269     def forward(self, a, b):
270         a = self.features_a(a).view(a.size(0), -1)
271         b = self.features_b(b).view(b.size(0), -1)
272         x = torch.cat((a, b), 1)
273         return self.fully_connected(x)
274
275
276 ######################################################################
277
278
279 class NetForImageValuesPair(nn.Module):
280     def __init__(self):
281         super().__init__()
282         self.features_a = nn.Sequential(
283             nn.Conv2d(1, 16, kernel_size=5),
284             nn.MaxPool2d(3),
285             nn.ReLU(),
286             nn.Conv2d(16, 32, kernel_size=5),
287             nn.MaxPool2d(2),
288             nn.ReLU(),
289         )
290
291         self.features_b = nn.Sequential(
292             nn.Linear(2, 32),
293             nn.ReLU(),
294             nn.Linear(32, 32),
295             nn.ReLU(),
296             nn.Linear(32, 128),
297             nn.ReLU(),
298         )
299
300         self.fully_connected = nn.Sequential(
301             nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
302         )
303
304     def forward(self, a, b):
305         a = self.features_a(a).view(a.size(0), -1)
306         b = self.features_b(b).view(b.size(0), -1)
307         x = torch.cat((a, b), 1)
308         return self.fully_connected(x)
309
310
311 ######################################################################
312
313
314 class NetForSequencePair(nn.Module):
315     def feature_model(self):
316         kernel_size = 11
317         pooling_size = 4
318         return nn.Sequential(
319             nn.Conv1d(1, self.nc, kernel_size=kernel_size),
320             nn.AvgPool1d(pooling_size),
321             nn.LeakyReLU(),
322             nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
323             nn.AvgPool1d(pooling_size),
324             nn.LeakyReLU(),
325             nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
326             nn.AvgPool1d(pooling_size),
327             nn.LeakyReLU(),
328             nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
329             nn.AvgPool1d(pooling_size),
330             nn.LeakyReLU(),
331         )
332
333     def __init__(self):
334         super().__init__()
335
336         self.nc = 32
337         self.nh = 256
338
339         self.features_a = self.feature_model()
340         self.features_b = self.feature_model()
341
342         self.fully_connected = nn.Sequential(
343             nn.Linear(2 * self.nc, self.nh), nn.ReLU(), nn.Linear(self.nh, 1)
344         )
345
346     def forward(self, a, b):
347         a = a.view(a.size(0), 1, a.size(1))
348         a = self.features_a(a)
349         a = F.avg_pool1d(a, a.size(2))
350
351         b = b.view(b.size(0), 1, b.size(1))
352         b = self.features_b(b)
353         b = F.avg_pool1d(b, b.size(2))
354
355         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
356         return self.fully_connected(x)
357
358
359 ######################################################################
360
361 if args.data == "image_pair":
362     create_pairs = create_image_pairs
363     model = NetForImagePair()
364
365 elif args.data == "image_values_pair":
366     create_pairs = create_image_values_pairs
367     model = NetForImageValuesPair()
368
369 elif args.data == "sequence_pair":
370     create_pairs = create_sequences_pairs
371     model = NetForSequencePair()
372
373     ######################
374     ## Save for figures
375     a, b, c = create_pairs()
376     for k in range(10):
377         file = open(f"train_{k:02d}.dat", "w")
378         for i in range(a.size(1)):
379             file.write(f"{a[k, i]:f} {b[k,i]:f}\n")
380         file.close()
381     ######################
382
383 else:
384     raise Exception("Unknown data " + args.data)
385
386 ######################################################################
387 # Train
388
389 print(f"nb_parameters {sum(x.numel() for x in model.parameters())}")
390
391 model.to(device)
392
393 input_a, input_b, classes = create_pairs(train=True)
394
395 for e in range(args.nb_epochs):
396     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
397
398     input_br = input_b[torch.randperm(input_b.size(0))]
399
400     acc_mi = 0.0
401
402     for batch_a, batch_b, batch_br in zip(
403         input_a.split(args.batch_size),
404         input_b.split(args.batch_size),
405         input_br.split(args.batch_size),
406     ):
407         mi = (
408             model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
409         )
410         acc_mi += mi.item()
411         loss = -mi
412         optimizer.zero_grad()
413         loss.backward()
414         optimizer.step()
415
416     acc_mi /= input_a.size(0) // args.batch_size
417
418     print(f"{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")
419
420     sys.stdout.flush()
421
422 ######################################################################
423 # Test
424
425 input_a, input_b, classes = create_pairs(train=False)
426
427 input_br = input_b[torch.randperm(input_b.size(0))]
428
429 acc_mi = 0.0
430
431 for batch_a, batch_b, batch_br in zip(
432     input_a.split(args.batch_size),
433     input_b.split(args.batch_size),
434     input_br.split(args.batch_size),
435 ):
436     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
437     acc_mi += mi.item()
438
439 acc_mi /= input_a.size(0) // args.batch_size
440
441 print(f"test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")
442
443 ######################################################################