#########################################################################
# This program is free software: you can redistribute it and/or modify  #
# it under the terms of the version 3 of the GNU General Public License #
# as published by the Free Software Foundation.                         #
#                                                                       #
# This program is distributed in the hope that it will be useful, but   #
# WITHOUT ANY WARRANTY; without even the implied warranty of            #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
# General Public License for more details.                              #
#                                                                       #
# You should have received a copy of the GNU General Public License     #
# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
#                                                                       #
# Written by and Copyright (C) Francois Fleuret                         #
# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
#########################################################################
import argparse, math, sys
from copy import deepcopy

import torch, torchvision
from torch import nn
import torch.nn.functional as F
######################################################################

# Run on the GPU when one is available. cudnn.benchmark lets cudnn
# pick the fastest convolution algorithms for the fixed input sizes.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

######################################################################
parser = argparse.ArgumentParser(
    description = '''An implementation of a Mutual Information estimator with a deep model

Three different toy data-sets are implemented:

 (1) Two MNIST images of same class. The "true" MI is the log of the
     number of used MNIST classes.

 (2) One MNIST image and a pair of real numbers whose difference is
     the class of the image. The "true" MI is the log of the number of
     used MNIST classes.

 (3) Two 1d sequences, the first with a single peak, the second with
     two peaks, and the height of the peak in the first is the
     difference of timing of the peaks in the second. The "true" MI is
     the log of the number of possible peak heights.''',

    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--data',
                    type = str, default = 'image_pair',
                    help = 'What data: image_pair, image_values_pair, sequence_pair')

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

parser.add_argument('--mnist_classes',
                    type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
                    help = 'What MNIST classes to use')

parser.add_argument('--nb_classes',
                    type = int, default = 2,
                    help = 'How many classes for sequences')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Batch size')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3,
                    help = 'Learning rate')

parser.add_argument('--independent', action = 'store_true',
                    help = 'Should the pair components be independent')

######################################################################

args = parser.parse_args()

# A negative seed means "do not seed"
if args.seed >= 0:
    torch.manual_seed(args.seed)

# Parse the comma-separated class list explicitly instead of eval(),
# which would execute arbitrary code coming from the command line
used_MNIST_classes = torch.tensor(
    [ int(k) for k in args.mnist_classes.split(',') ],
    device = device
)

######################################################################
def entropy(target):
    """Return the empirical entropy (in nats) of the class labels in
    the 1d long tensor `target`, estimated from their frequencies.
    Classes that do not occur contribute nothing."""
    probas = []
    for k in range(target.max().item() + 1):
        n = (target == k).sum().item()
        if n > 0: probas.append(n)
    probas = torch.tensor(probas).float()
    probas /= probas.sum()
    return - (probas * probas.log()).sum().item()

######################################################################
# Load MNIST. The `.data` / `.targets` attributes replace the
# deprecated `train_data` / `train_labels` / `test_data` /
# `test_labels` ones removed from recent torchvision.
train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
train_input = train_set.data.view(-1, 1, 28, 28).to(device).float()
train_target = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).to(device).float()
test_target = test_set.targets.to(device)

# Normalize both sets with the train-set statistics
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

######################################################################
# Returns a triplet of tensors (a, b, c), where a and b contain each
# half of the samples, with a[i] and b[i] of same class for any i, and
# c is a 1d long tensor real classes

def create_image_pairs(train = False):
    ua, ub, uc = [], [], []

    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    for i in used_MNIST_classes:
        # Indices of the samples of class i, shuffled
        used_indices = torch.arange(input.size(0), device = target.device)\
            .masked_select(target == i.item())
        x = input[used_indices]
        x = x[torch.randperm(x.size(0))]
        # Split the class samples in two halves, paired by position
        hs = x.size(0) // 2
        ua.append(x.narrow(0, 0, hs))
        ub.append(x.narrow(0, hs, hs))
        uc.append(target[used_indices])

    a = torch.cat(ua, 0)
    b = torch.cat(ub, 0)
    c = torch.cat(uc, 0)

    perm = torch.randperm(a.size(0))
    a = a[perm].contiguous()

    if args.independent:
        # Re-draw the permutation for b so the two components of each
        # pair become independent
        perm = torch.randperm(a.size(0))
    b = b[perm].contiguous()

    return a, b, c

######################################################################
# Returns a triplet a, b, c where a are the standard MNIST images, c
# the classes, and b is a Nx2 tensor, with for every n:
#
#   b[n, 0] ~ Uniform(0, 10)
#   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]

def create_image_values_pairs(train = False):
    if train:
        input, target = train_input, train_target
    else:
        input, target = test_input, test_target

    # Per-class boolean flag, then mapped through the targets to get a
    # per-sample mask selecting classes in used_MNIST_classes.
    # (bool dtype: masked_select on uint8 masks is deprecated)
    m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.bool, device = target.device)
    m[used_MNIST_classes] = 1
    m = m[target]
    used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)

    input = input[used_indices].contiguous()
    target = target[used_indices].contiguous()

    a = input

    b = a.new(a.size(0), 2)
    b[:, 0].uniform_(0.0, 10.0)
    b[:, 1].uniform_(0.0, 0.5)

    if args.independent:
        # Use a random class instead of the true one, so that b
        # carries no information about a
        b[:, 1] += b[:, 0] + \
            used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
    else:
        b[:, 1] += b[:, 0] + target.float()

    return a, b, target

######################################################################
def create_sequences_pairs(train = False):
    """Return (a, b, c): a has one decaying peak whose height encodes
    the class c, b has two peaks whose time difference encodes the
    same class (or an independent one with --independent)."""
    nb, length = 10000, 1024
    # Amplitude of the additive Gaussian noise
    # NOTE(review): value reconstructed from context — confirm
    noise_level = 2e-2

    ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
    if args.independent:
        hb = torch.randint(args.nb_classes, (nb, ), device = device)
    else:
        hb = ha

    # First sequence: a single exponentially-decaying peak, whose
    # height is proportional to the class ha
    pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
    a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    a = a - pos.view(nb, 1)
    a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
    a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
    noise = a.new(a.size()).normal_(0, noise_level)
    a = a + noise

    # Second sequence: two peaks, the second delayed w.r.t. the first
    # by an amount encoding hb
    pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
    b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    b1 = b1 - pos.view(nb, 1)
    b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
    pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
    # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
    b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
    b2 = b2 - pos.view(nb, 1)
    b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25

    b = b1 + b2
    noise = b.new(b.size()).normal_(0, noise_level)
    b = b + noise

    # a = (a - a.mean()) / a.std()
    # b = (b - b.mean()) / b.std()

    return a, b, ha

######################################################################
class NetForImagePair(nn.Module):
    """Scoring network for pairs of MNIST images: two small
    convolutional towers, one per image, whose features are
    concatenated and mapped by an MLP to a single scalar score."""

    def __init__(self):
        super(NetForImagePair, self).__init__()
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        self.features_b = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        # Each tower maps 1x28x28 to 32x2x2 = 128 features, so the
        # concatenation is 256-dimensional
        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)

######################################################################
class NetForImageValuesPair(nn.Module):
    """Scoring network for (MNIST image, 2d value) pairs: a
    convolutional tower for the image, an MLP for the value pair,
    concatenated features mapped to a single scalar score."""

    def __init__(self):
        super(NetForImageValuesPair, self).__init__()
        self.features_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size = 5),
            nn.MaxPool2d(3), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size = 5),
            nn.MaxPool2d(2), nn.ReLU(),
        )

        self.features_b = nn.Sequential(
            nn.Linear(2, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 128), nn.ReLU(),
        )

        # Image tower gives 32x2x2 = 128 features, value MLP gives
        # 128, hence 256 after concatenation
        self.fully_connected = nn.Sequential(
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        )

    def forward(self, a, b):
        a = self.features_a(a).view(a.size(0), -1)
        b = self.features_b(b).view(b.size(0), -1)
        x = torch.cat((a, b), 1)
        return self.fully_connected(x)

######################################################################
class NetForSequencePair(nn.Module):
    """Scoring network for pairs of 1d sequences: two identical conv1d
    towers, global average pooling over time, then an MLP producing a
    single scalar score."""

    def feature_model(self):
        # One-line purpose: build one conv1d tower shared in structure
        # by both sequence branches.
        # NOTE(review): kernel/pooling sizes reconstructed — the
        # original values are garbled. These leave 2 time steps for a
        # length-1024 input, which forward() average-pools away.
        kernel_size = 5
        pooling_size = 4
        return nn.Sequential(
            nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
            nn.AvgPool1d(pooling_size),
            nn.ReLU(),
        )

    def __init__(self):
        super(NetForSequencePair, self).__init__()

        self.nc = 32   # nb of channels in the convolution towers
        self.nh = 256  # nb of hidden units of the final MLP

        self.features_a = self.feature_model()
        self.features_b = self.feature_model()

        self.fully_connected = nn.Sequential(
            nn.Linear(2 * self.nc, self.nh),
            nn.ReLU(),
            nn.Linear(self.nh, 1)
        )

    def forward(self, a, b):
        # Add the channel dimension, then average-pool whatever
        # temporal extent remains after the towers
        a = a.view(a.size(0), 1, a.size(1))
        a = self.features_a(a)
        a = F.avg_pool1d(a, a.size(2))

        b = b.view(b.size(0), 1, b.size(1))
        b = self.features_b(b)
        b = F.avg_pool1d(b, b.size(2))

        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
        return self.fully_connected(x)

######################################################################
# Select the data generator and the matching model
if args.data == 'image_pair':
    create_pairs = create_image_pairs
    model = NetForImagePair()

elif args.data == 'image_values_pair':
    create_pairs = create_image_values_pairs
    model = NetForImageValuesPair()

elif args.data == 'sequence_pair':
    create_pairs = create_sequences_pairs
    model = NetForSequencePair()

    ######################
    # Dump a few sequences to file, e.g. to produce figures
    a, b, c = create_pairs()
    for k in range(10):
        # with-statement guarantees the file is closed even on error
        with open(f'train_{k:02d}.dat', 'w') as file:
            for i in range(a.size(1)):
                file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
    ######################

else:
    raise Exception('Unknown data ' + args.data)

######################################################################
print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')

model.to(device)

input_a, input_b, classes = create_pairs(train = True)

# Create the optimizer once, outside the epoch loop, so that Adam's
# running moment estimates are not reset at every epoch
optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)

for e in range(args.nb_epochs):

    # Pairing input_a with a shuffled input_b provides samples from
    # the product of the marginals
    input_br = input_b[torch.randperm(input_b.size(0))]

    acc_mi = 0.0

    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        # Donsker-Varadhan lower bound on the mutual information
        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
        acc_mi += mi.item()
        loss = - mi
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc_mi /= (input_a.size(0) // args.batch_size)

    # Both the MI estimate and the label entropy, converted to bits
    print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')

    sys.stdout.flush()

######################################################################
input_a, input_b, classes = create_pairs(train = False)

acc_mi = 0.0

# No gradients needed at evaluation time
with torch.no_grad():
    input_br = input_b[torch.randperm(input_b.size(0))]

    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        # Same Donsker-Varadhan estimate as in training
        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
        acc_mi += mi.item()

acc_mi /= (input_a.size(0) // args.batch_size)

print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')

######################################################################