# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import argparse, math, sys
from copy import deepcopy

import torch, torchvision
from torch import nn
import torch.nn.functional as F

######################################################################
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

######################################################################
parser = argparse.ArgumentParser(
    description="""An implementation of a Mutual Information estimator with a deep model.

Three different toy data-sets are implemented, each consisting of
pairs of samples that may be from different spaces:

(1) Two MNIST images of the same class. The "true" MI is the log of the
number of used MNIST classes.

(2) One MNIST image and a pair of real numbers whose difference is
the class of the image. The "true" MI is the log of the number of
used MNIST classes.

(3) Two 1d sequences, the first with a single peak, the second with
two peaks, and the height of the peak in the first is the
difference of timing of the peaks in the second. The "true" MI is
the log of the number of possible peak heights.""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--data",
    type=str,
    default="image_pair",
    help="What data: image_pair, image_values_pair, sequence_pair",
)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

parser.add_argument(
    "--mnist_classes",
    type=str,
    default="0, 1, 3, 5, 6, 7, 8, 9",
    help="What MNIST classes to use",
)

parser.add_argument(
    "--nb_classes", type=int, default=2, help="How many classes for sequences"
)

parser.add_argument("--nb_epochs", type=int, default=50, help="How many epochs")

parser.add_argument("--batch_size", type=int, default=100, help="Batch size")

parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")

parser.add_argument(
    "--independent",
    action="store_true",
    help="Should the pair components be independent",
)

######################################################################
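# Sanity check on the reported "true" MI: with the default
# --mnist_classes "0, 1, 3, 5, 6, 7, 8, 9" there are 8 classes, so the
# reference value for the image-pair data is log(8), i.e. 3 bits (the
# estimates below are printed in bits, divided by log(2)).
#
# Example invocations (the file name mine_mnist.py is only illustrative):
#
#   python mine_mnist.py --data image_pair
#   python mine_mnist.py --data sequence_pair --nb_classes 4
#   python mine_mnist.py --data image_pair --independent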
args = parser.parse_args()

if args.seed >= 0:
    torch.manual_seed(args.seed)

used_MNIST_classes = torch.tensor(eval("[" + args.mnist_classes + "]"), device=device)

######################################################################
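# The entropy of the empirical class distribution (in nats) is printed
# next to each MI estimate as the reference "true" MI described in the
# help message above.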
def entropy(target):
    probas = []
    for k in range(target.max() + 1):
        n = (target == k).sum().item()
        if n > 0:
            probas.append(n)
    probas = torch.tensor(probas).float()
    probas /= probas.sum()
    return -(probas * probas.log()).sum().item()
######################################################################
train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
train_input = train_set.data.view(-1, 1, 28, 28).to(device).float()
train_target = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST("./data/mnist/", train=False, download=True)
test_input = test_set.data.view(-1, 1, 28, 28).to(device).float()
test_target = test_set.targets.to(device)

mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

######################################################################
# Returns a triplet of tensors (a, b, c), where a and b contain each
# half of the samples, with a[i] and b[i] of the same class for every i,
# and c is a 1d long tensor of the real classes.
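#
# With --independent, the two halves are re-shuffled with different
# permutations, so a[i] and b[i] are no longer class-matched and the
# estimator should report a mutual information close to zero.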
def create_image_pairs(train=False):
    ua, ub, uc = [], [], []

    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    for i in used_MNIST_classes:
        used_indices = torch.arange(input.size(0), device=target.device).masked_select(
            target == i
        )
        x = input[used_indices]
        x = x[torch.randperm(x.size(0))]
        hs = x.size(0) // 2
        ua.append(x.narrow(0, 0, hs))
        ub.append(x.narrow(0, hs, hs))
        uc.append(target[used_indices])

    a = torch.cat(ua, 0)
    b = torch.cat(ub, 0)
    c = torch.cat(uc, 0)
    perm = torch.randperm(a.size(0))
    a = a[perm].contiguous()

    if args.independent:
        perm = torch.randperm(a.size(0))
    b = b[perm].contiguous()

    return a, b, c
######################################################################
# Returns a triplet a, b, c where a are the standard MNIST images, c
# the classes, and b is a Nx2 tensor, with for every n:
#
#   b[n, 0] ~ Uniform(0, 10)
#   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
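#
# With --independent, c[n] is replaced by a class drawn uniformly at
# random among the used MNIST classes, so b carries no information about
# the image and the true MI is zero.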
def create_image_values_pairs(train=False):
    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    m = torch.zeros(
        used_MNIST_classes.max() + 1, dtype=torch.bool, device=target.device
    )
    m[used_MNIST_classes] = 1
    m = m[target]
    used_indices = torch.arange(input.size(0), device=target.device).masked_select(m)

    input = input[used_indices].contiguous()
    target = target[used_indices].contiguous()
    a = input

    b = a.new(a.size(0), 2)
    b[:, 0].uniform_(0.0, 10.0)
    b[:, 1].uniform_(0.0, 0.5)

    if args.independent:
        b[:, 1] += (
            b[:, 0]
            + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
        )
    else:
        b[:, 1] += b[:, 0] + target.float()

    return a, b, target
######################################################################
def create_sequences_pairs(train=False):
    nb, length = 10000, 1024
    noise_level = 2e-2  # amplitude of the additive Gaussian noise (assumed value)

    ha = torch.randint(args.nb_classes, (nb,), device=device) + 1
    if args.independent:
        hb = torch.randint(args.nb_classes, (nb,), device=device)
    else:
        hb = ha

    pos = torch.empty(nb, device=device).uniform_(0.0, 0.9)
    a = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
    a = a - pos.view(nb, 1)
    a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
    noise = a.new(a.size()).normal_(0, noise_level)
    a = a + noise

    pos = torch.empty(nb, device=device).uniform_(0.0, 0.5)
    b1 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
    b1 = b1 - pos.view(nb, 1)
    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
    pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
    # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
    b2 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
    b2 = b2 - pos.view(nb, 1)
    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25

    b = b1 + b2
    noise = b.new(b.size()).normal_(0, noise_level)
    b = b + noise

    return a, b, ha
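
# The two returned sequences are (nb, length) float tensors and ha is the
# (nb,) long tensor of classes shared by the pairs (unless --independent).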
######################################################################
class NetForImagePair(nn.Module):
    def __init__(self):
        super().__init__()
        # 1x28x28 -> 16x8x8 -> 32x2x2, so each image yields 128 features
        # (pooling sizes chosen to match the 256-d input of the head)
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5),
            nn.MaxPool2d(3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

        self.features_b = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5),
            nn.MaxPool2d(3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)


######################################################################
class NetForImageValuesPair(nn.Module):
    def __init__(self):
        super().__init__()
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5),
            nn.MaxPool2d(3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

        # Maps the two values to a 128-d feature vector, so that the
        # concatenation with the 128-d image features is 256-d (the
        # intermediate widths are assumed, only the 2 -> 128 mapping matters)
        self.features_b = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.ReLU(),
        )

        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)


######################################################################
class NetForSequencePair(nn.Module):
    def feature_model(self):
        # Kernel and pooling sizes are assumed values; they only need to
        # keep the sequence length positive through the four blocks.
        kernel_size = 11
        pooling_size = 2
        return nn.Sequential(
            nn.Conv1d(1, self.nc, kernel_size=kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Number of channels and of hidden units (assumed values)
        self.nc = 32
        self.nh = 256

        self.features_a = self.feature_model()
        self.features_b = self.feature_model()

        self.fully_connected = nn.Sequential(
            nn.Linear(2 * self.nc, self.nh), nn.ReLU(), nn.Linear(self.nh, 1)
        )

    def forward(self, a, b):
        a = a.view(a.size(0), 1, a.size(1))
        a = self.features_a(a)
        a = F.avg_pool1d(a, a.size(2))

        b = b.view(b.size(0), 1, b.size(1))
        b = self.features_b(b)
        b = F.avg_pool1d(b, b.size(2))

        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
        return self.fully_connected(x)
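
# Each feature extractor ends with a global average pooling over the
# remaining time axis (F.avg_pool1d over the full length), so every
# sequence is summarized by a fixed nc-dimensional descriptor before the
# two descriptors are concatenated and scored.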
######################################################################
if args.data == "image_pair":
    create_pairs = create_image_pairs
    model = NetForImagePair()

elif args.data == "image_values_pair":
    create_pairs = create_image_values_pairs
    model = NetForImageValuesPair()

elif args.data == "sequence_pair":
    create_pairs = create_sequences_pairs
    model = NetForSequencePair()

    ######################
    # Dump a few example sequence pairs to .dat files for plotting
    # (the number of dumped pairs is arbitrary)
    a, b, c = create_pairs()
    for k in range(10):
        file = open(f"train_{k:02d}.dat", "w")
        for i in range(a.size(1)):
            file.write(f"{a[k, i]:f} {b[k, i]:f}\n")
        file.close()
    ######################

else:
    raise Exception("Unknown data " + args.data)
######################################################################
389 print(f"nb_parameters {sum(x.numel() for x in model.parameters())}")
input_a, input_b, classes = create_pairs(train=True)

for e in range(args.nb_epochs):
    # the optimizer (and its running moments) is re-created at every epoch
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    input_br = input_b[torch.randperm(input_b.size(0))]

    acc_mi = 0.0

    for batch_a, batch_b, batch_br in zip(
        input_a.split(args.batch_size),
        input_b.split(args.batch_size),
        input_br.split(args.batch_size),
    ):
        mi = (
            model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
        )
        acc_mi += mi.item()
        loss = -mi
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc_mi /= input_a.size(0) // args.batch_size

    print(f"{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")

    sys.stdout.flush()
######################################################################
input_a, input_b, classes = create_pairs(train=False)

input_br = input_b[torch.randperm(input_b.size(0))]

acc_mi = 0.0

for batch_a, batch_b, batch_br in zip(
    input_a.split(args.batch_size),
    input_b.split(args.batch_size),
    input_br.split(args.batch_size),
):
    mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
    acc_mi += mi.item()

acc_mi /= input_a.size(0) // args.batch_size

print(f"test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")
######################################################################