3 import argparse, math, sys
4 from copy import deepcopy
6 import torch, torchvision
9 import torch.nn.functional as F
######################################################################

# Compute device for every tensor and model in this script: the GPU when
# CUDA is available, otherwise the CPU. On GPU, cuDNN autotuning is enabled.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True
19 ######################################################################
# Command-line options of the script. ArgumentDefaultsHelpFormatter makes
# --help print each option's default value.
21 parser = argparse.ArgumentParser(
22 description = 'An implementation of Mutual Information estimator with a deep model',
23 formatter_class = argparse.ArgumentDefaultsHelpFormatter
# NOTE(review): the closing parenthesis of ArgumentParser(...) is not
# visible in this listing — confirm it is present.
# --data selects the pair generator and critic network; the dispatch on
# args.data handles 'image_pair', 'image_values_pair' and 'sequence_pair'
# and appears to raise for any other value.
26 parser.add_argument('--data',
27 type = str, default = 'image_pair',
# NOTE(review): the rest of the --data call (help text, closing
# parenthesis) is not visible here — confirm.
# A negative --seed is documented as "no seeding"; confirm the seeding
# code honours this.
30 parser.add_argument('--seed',
31 type = int, default = 0,
32 help = 'Random seed (default 0, < 0 is no seeding)')
# Comma-separated list of MNIST digit classes used by the image datasets.
34 parser.add_argument('--mnist_classes',
35 type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
36 help = 'What MNIST classes to use')
# Number of distinct peak heights (ha is drawn in 1..nb_classes) for the
# 'sequence_pair' data.
38 parser.add_argument('--nb_classes',
39 type = int, default = 2,
40 help = 'How many classes for sequences')
# Number of training passes; fresh pairs are drawn at every epoch.
42 parser.add_argument('--nb_epochs',
43 type = int, default = 50,
44 help = 'How many epochs')
# Minibatch size used to split the pairs during training and testing.
46 parser.add_argument('--batch_size',
47 type = int, default = 100,
# NOTE(review): the end of the --batch_size call is not visible here —
# confirm.
######################################################################

def entropy(target):
    """Return the empirical Shannon entropy, in nats, of a tensor of labels.

    Only labels >= 0 are counted. Empty classes contribute nothing, and an
    empty (or all-negative) input has entropy 0.0. Callers divide by
    log(2) to get bits.
    """
    labels = target.flatten()
    labels = labels[labels >= 0]
    if labels.numel() == 0:
        return 0.0
    counts = torch.bincount(labels)
    # Drop empty classes: 0 * log(0) would otherwise produce NaN.
    counts = counts[counts > 0].float()
    probas = counts / counts.sum()
    return - (probas * probas.log()).sum().item()
def robust_log_mean_exp(x):
    """Return log(mean(exp(x))) over all elements of x as a 0-d tensor.

    Computed with logsumexp, so large values of x do not overflow exp().
    The result equals x.exp().mean().log() whenever that expression is
    finite, and it is differentiable. x must be a non-empty floating-point
    tensor.
    """
    x = x.flatten()
    # log(mean(exp(x))) = logsumexp(x) - log(N)
    return torch.logsumexp(x, 0) - math.log(x.numel())
######################################################################

args = parser.parse_args()

# Honour the documented --seed contract: a negative seed means no seeding.
if args.seed >= 0:
    torch.manual_seed(args.seed)

# --mnist_classes is a comma-separated list of digits such as '0, 1, 3'.
# Parse it explicitly rather than eval()-ing command-line text.
used_MNIST_classes = torch.tensor(
    [int(s) for s in args.mnist_classes.split(',') if s.strip()],
    device = device)
######################################################################

# Load MNIST (downloading it into ./data/mnist/ if needed). Images become
# float tensors of shape (N, 1, 28, 28) on `device`; labels are moved to
# `device` as well. Uses the current torchvision attributes `data` and
# `targets` (train_data/train_labels/test_data/test_labels are deprecated).
train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
train_input = train_set.data.view(-1, 1, 28, 28).to(device).float()
train_target = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).to(device).float()
test_target = test_set.targets.to(device)

# Standardize both sets in place with the statistics of the training set only.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
######################################################################

def create_image_pairs(train = False):
    """Build pairs of distinct MNIST images that share the same class.

    Returns a triplet of tensors (a, b, c): a and b each contain one half
    of the samples, with a[i] and b[i] of the same class for every i, and
    c is a 1d long tensor holding the class of each pair (aligned with a
    and b). Only classes listed in used_MNIST_classes are used. train
    selects the training set, otherwise the test set is used.
    """
    ua, ub, uc = [], [], []

    if train:
        images, labels = train_input, train_target
    else:
        images, labels = test_input, test_target

    for i in used_MNIST_classes:
        x = images[labels == i]
        x = x[torch.randperm(x.size(0))]
        # Split the shuffled samples of this class into two equal halves;
        # an odd leftover sample is dropped.
        hs = x.size(0) // 2
        ua.append(x.narrow(0, 0, hs))
        ub.append(x.narrow(0, hs, hs))
        # One label per pair, so that c stays aligned with a and b.
        uc.append(labels.new_full((hs,), i.item()))

    a = torch.cat(ua, 0)
    b = torch.cat(ub, 0)
    c = torch.cat(uc, 0)

    # Shuffle pairs across classes, applying the same permutation to c.
    perm = torch.randperm(a.size(0))
    a = a[perm].contiguous()
    b = b[perm].contiguous()
    c = c[perm].contiguous()

    return a, b, c
######################################################################

def create_image_values_pairs(train = False):
    """Pair MNIST images with two real values whose difference encodes the class.

    Returns a triplet (a, b, c) where a are the MNIST images of the classes
    listed in used_MNIST_classes, c their classes, and b is an Nx2 tensor
    with, for every n:

        b[n, 0] ~ Uniform(0, 10)
        b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]

    train selects the training set, otherwise the test set is used.
    """
    if train:
        images, labels = train_input, train_target
    else:
        images, labels = test_input, test_target

    # Per-sample boolean mask: True when the label is one of the used
    # classes. Comparing against the class list works for any subset,
    # whereas a lookup table sized max(used) + 1 cannot be indexed with
    # labels above that maximum.
    keep = (labels.view(-1, 1) == used_MNIST_classes.view(1, -1)).any(1)

    a = images[keep].contiguous()
    c = labels[keep].contiguous()

    b = a.new_empty((a.size(0), 2))
    b[:, 0].uniform_(0.0, 10.0)
    b[:, 1].uniform_(0.0, 0.5)
    b[:, 1] += b[:, 0] + c.float()

    return a, b, c
157 ######################################################################
# Generate nb = 10000 synthetic pairs of 1d signals of length 1024, sampled
# on a [0, 1] time axis:
#  - a: one peak at pos ~ U(0, 0.9) that decays exponentially with a
#    half-life of 0.1 and has height ha / (1 + nb_classes), where ha is
#    drawn uniformly in {1, ..., nb_classes};
#  - b1, b2: two peaks of height 0.25 with the same decay, b1 at
#    pos ~ U(0, 0.5) and b2 delayed by hb / (nb_classes + 1) * 0.5.
# NOTE(review): `train` is not used in the visible lines, so train and test
# sets seem to come from the same random process — confirm.
# NOTE(review): noise_level and hb are used below but their definitions
# are not visible here. The commented-out line suggests hb was once drawn
# independently of ha — confirm how hb is defined.
159 def create_sequences_pairs(train = False):
160 nb, length = 10000, 1024
# Peak height class in 1..nb_classes (randint excludes its upper bound).
163 ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
164 # hb = torch.randint(args.nb_classes, (nb, ), device = device)
167 pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
168 a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
169 a = a - pos.view(nb, 1)
# Zero before pos, then exp(-t * log 2 / 0.1): halves every 0.1 time units.
170 a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
171 a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
# Gaussian noise with std noise_level; presumably added to a on a line not
# visible here — confirm.
172 noise = a.new(a.size()).normal_(0, noise_level)
175 pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
176 b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
177 b1 = b1 - pos.view(nb, 1)
178 b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
# The second peak's delay is what carries the class information in b.
179 pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
180 b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
181 b2 = b2 - pos.view(nb, 1)
182 b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
# NOTE(review): b is used below but its construction (presumably from b1
# and b2) and the addition of the noise are not visible here — confirm.
185 noise = b.new(b.size()).normal_(0, noise_level)
188 # a = (a - a.mean()) / a.std()
189 # b = (b - b.mean()) / b.std()
# NOTE(review): the return statement is not visible in this listing.
193 ######################################################################
# Critic network for pairs of MNIST images. Each image goes through its own
# convolutional branch (features_a / features_b: same architecture,
# separate weights). The flattened features are concatenated and
# fully_connected maps them to the pair's score.
# For 1x28x28 inputs each branch yields 32x2x2 = 128 features
# (28 -conv5-> 24 -maxpool3-> 8 -conv5-> 4 -maxpool2-> 2).
# NOTE(review): the __init__ header, the closing parentheses of the
# Sequential blocks and the layers of fully_connected are not visible in
# this listing — confirm.
195 class NetForImagePair(nn.Module):
197 super(NetForImagePair, self).__init__()
198 self.features_a = nn.Sequential(
199 nn.Conv2d(1, 16, kernel_size = 5),
200 nn.MaxPool2d(3), nn.ReLU(),
201 nn.Conv2d(16, 32, kernel_size = 5),
202 nn.MaxPool2d(2), nn.ReLU(),
205 self.features_b = nn.Sequential(
206 nn.Conv2d(1, 16, kernel_size = 5),
207 nn.MaxPool2d(3), nn.ReLU(),
208 nn.Conv2d(16, 32, kernel_size = 5),
209 nn.MaxPool2d(2), nn.ReLU(),
210 self.fully_connected = nn.Sequential(
# Score a batch of pairs: each branch's output is flattened per sample,
# then the two feature vectors are concatenated along dim 1.
218 def forward(self, a, b):
219 a = self.features_a(a).view(a.size(0), -1)
220 b = self.features_b(b).view(b.size(0), -1)
221 x = torch.cat((a, b), 1)
222 return self.fully_connected(x)
224 ######################################################################
# Critic network for (MNIST image, pair of real values) samples.
# features_a is the same convolutional branch as in NetForImagePair
# (128 features for 1x28x28 inputs). features_b is an MLP mapping the 2
# input values to 128 features (2 -> 32 -> 32 -> 128, ReLU after each
# layer). The concatenated features go through fully_connected.
# NOTE(review): the __init__ header, the closing parentheses of the
# Sequential blocks and the layers of fully_connected are not visible in
# this listing — confirm.
226 class NetForImageValuesPair(nn.Module):
228 super(NetForImageValuesPair, self).__init__()
229 self.features_a = nn.Sequential(
230 nn.Conv2d(1, 16, kernel_size = 5),
231 nn.MaxPool2d(3), nn.ReLU(),
232 nn.Conv2d(16, 32, kernel_size = 5),
233 nn.MaxPool2d(2), nn.ReLU(),
236 self.features_b = nn.Sequential(
237 nn.Linear(2, 32), nn.ReLU(),
238 nn.Linear(32, 32), nn.ReLU(),
239 nn.Linear(32, 128), nn.ReLU(),
242 self.fully_connected = nn.Sequential(
# Score a batch of samples: flatten both branch outputs per sample and
# concatenate them along dim 1.
248 def forward(self, a, b):
249 a = self.features_a(a).view(a.size(0), -1)
250 b = self.features_b(b).view(b.size(0), -1)
251 x = torch.cat((a, b), 1)
252 return self.fully_connected(x)
254 ######################################################################
# Critic network for pairs of 1d sequences. Each sequence goes through its
# own stack of four Conv1d + AvgPool1d stages with self.nc channels. The
# result is globally average-pooled over time, and the two nc-dimensional
# vectors are concatenated and scored by fully_connected (2 * nc -> nh -> 1).
# NOTE(review): kernel_size, pooling_size, self.nc, self.nh, the activation
# layers between stages and the __init__ header are on lines not visible
# in this listing — confirm. With 1024-long inputs, kernel_size and
# pooling_size must leave a non-empty time axis after four stages.
256 class NetForSequencePair(nn.Module):
# Build one feature branch; called twice so the branches do not share weights.
258 def feature_model(self):
261 return nn.Sequential(
262 nn.Conv1d( 1, self.nc, kernel_size = kernel_size),
263 nn.AvgPool1d(pooling_size),
265 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
266 nn.AvgPool1d(pooling_size),
268 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
269 nn.AvgPool1d(pooling_size),
271 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
272 nn.AvgPool1d(pooling_size),
277 super(NetForSequencePair, self).__init__()
282 self.features_a = self.feature_model()
283 self.features_b = self.feature_model()
285 self.fully_connected = nn.Sequential(
286 nn.Linear(2 * self.nc, self.nh),
288 nn.Linear(self.nh, 1)
# a and b are 2d (batch, length) tensors; they are reshaped to single-channel
# (batch, 1, length) inputs for Conv1d.
291 def forward(self, a, b):
292 a = a.view(a.size(0), 1, a.size(1))
293 a = self.features_a(a)
# Average over the whole remaining time axis -> (batch, nc, 1).
294 a = F.avg_pool1d(a, a.size(2))
296 b = b.view(b.size(0), 1, b.size(1))
297 b = self.features_b(b)
298 b = F.avg_pool1d(b, b.size(2))
300 x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
301 return self.fully_connected(x)
303 ######################################################################
# Select the pair generator and the matching critic network from --data.
305 if args.data == 'image_pair':
306 create_pairs = create_image_pairs
307 model = NetForImagePair()
308 elif args.data == 'image_values_pair':
309 create_pairs = create_image_values_pairs
310 model = NetForImageValuesPair()
311 elif args.data == 'sequence_pair':
312 create_pairs = create_sequences_pairs
313 model = NetForSequencePair()
314 ######################################################################
# NOTE(review): the lines below look like a leftover debugging dump. They
# draw one set of pairs and write a[k, i], b[k, i] values to
# /tmp/train_XX.dat files. The enclosing loop over k is not visible, and
# the file is not visibly closed (a `with open(...)` block would guarantee
# that). Confirm whether this should run at all.
315 a, b, c = create_pairs()
317 file = open(f'/tmp/train_{k:02d}.dat', 'w')
318 for i in range(a.size(1)):
319 file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
322 ######################################################################
# Reject unknown --data values.
# NOTE(review): presumably this sits in the `else:` branch of the dispatch
# above; that line is not visible in this listing — confirm.
324 raise Exception('Unknown data ' + args.data)
######################################################################

print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))

model.to(device)

# Create the optimizer once: re-creating Adam at every epoch would reset
# its running moment estimates.
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

for e in range(args.nb_epochs):

    # Fresh pairs every epoch. Shuffling b breaks the pairing, giving
    # samples from the product of the marginals.
    input_a, input_b, classes = create_pairs(train = True)
    input_br = input_b[torch.randperm(input_b.size(0))]

    acc_mi, nb_batches = 0.0, 0

    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        # Donsker-Varadhan lower bound: E_joint[T] - log E_marginals[exp(T)].
        mi = model(batch_a, batch_b).mean() - robust_log_mean_exp(model(batch_a, batch_br))
        acc_mi += mi.item()
        nb_batches += 1
        loss = - mi
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average over the batches actually processed, including a partial last one.
    acc_mi /= max(nb_batches, 1)

    # Report the estimated MI and the class entropy, both in bits.
    print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
######################################################################

# Estimate the mutual information on freshly drawn test pairs, without
# updating the model. Shuffling b breaks the pairing between a and b.
input_a, input_b, classes = create_pairs(train = False)
input_br = input_b[torch.randperm(input_b.size(0))]

acc_mi, nb_batches = 0.0, 0

# Evaluation only: do not build autograd graphs.
with torch.no_grad():
    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                          input_b.split(args.batch_size),
                                          input_br.split(args.batch_size)):
        mi = model(batch_a, batch_b).mean() - robust_log_mean_exp(model(batch_a, batch_br))
        acc_mi += mi.item()
        nb_batches += 1

# Average over the batches actually processed, including a partial last one.
acc_mi /= max(nb_batches, 1)

print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))

######################################################################