Update.
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 import argparse, math, sys
4 from copy import deepcopy
5
6 import torch, torchvision
7
8 from torch import nn
9 import torch.nn.functional as F
10
11 ######################################################################
12
13 if torch.cuda.is_available():
14     torch.backends.cudnn.benchmark = True
15     device = torch.device('cuda')
16 else:
17     device = torch.device('cpu')
18
19 ######################################################################
20
21 parser = argparse.ArgumentParser(
22     description = 'An implementation of Mutual Information estimator with a deep model',
23     formatter_class = argparse.ArgumentDefaultsHelpFormatter
24 )
25
26 parser.add_argument('--data',
27                     type = str, default = 'image_pair',
28                     help = 'What data')
29
30 parser.add_argument('--seed',
31                     type = int, default = 0,
32                     help = 'Random seed (default 0, < 0 is no seeding)')
33
34 parser.add_argument('--mnist_classes',
35                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
36                     help = 'What MNIST classes to use')
37
38 parser.add_argument('--nb_classes',
39                     type = int, default = 2,
40                     help = 'How many classes for sequences')
41
42 parser.add_argument('--nb_epochs',
43                     type = int, default = 50,
44                     help = 'How many epochs')
45
46 parser.add_argument('--batch_size',
47                     type = int, default = 100,
48                     help = 'Batch size')
49
50 parser.add_argument('--independent', action = 'store_true',
51                     help = 'Should the pair components be independent')
52
53 ######################################################################
54
55 def entropy(target):
56     probas = []
57     for k in range(target.max() + 1):
58         n = (target == k).sum().item()
59         if n > 0: probas.append(n)
60     probas = torch.tensor(probas).float()
61     probas /= probas.sum()
62     return - (probas * probas.log()).sum().item()
63
64 def robust_log_mean_exp(x):
65     # a = x.max()
66     # return (x-a).exp().mean().log() + a
67     # a = x.max()
68     return x.exp().mean().log()
69
70 ######################################################################
71
72 args = parser.parse_args()
73
74 if args.seed >= 0:
75     torch.manual_seed(args.seed)
76
77 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
78
79 ######################################################################
80
81 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
82 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
83 train_target = train_set.train_labels.to(device)
84
85 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
86 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
87 test_target = test_set.test_labels.to(device)
88
89 mu, std = train_input.mean(), train_input.std()
90 train_input.sub_(mu).div_(std)
91 test_input.sub_(mu).div_(std)
92
93 ######################################################################
94
95 # Returns a triplet of tensors (a, b, c), where a and b contain each
96 # half of the samples, with a[i] and b[i] of same class for any i, and
97 # c is a 1d long tensor real classes
98
99 def create_image_pairs(train = False):
100     ua, ub, uc = [], [], []
101
102     if train:
103         input, target = train_input, train_target
104     else:
105         input, target = test_input, test_target
106
107     for i in used_MNIST_classes:
108         used_indices = torch.arange(input.size(0), device = target.device)\
109                             .masked_select(target == i.item())
110         x = input[used_indices]
111         x = x[torch.randperm(x.size(0))]
112         hs = x.size(0)//2
113         ua.append(x.narrow(0, 0, hs))
114         ub.append(x.narrow(0, hs, hs))
115         uc.append(target[used_indices])
116
117     a = torch.cat(ua, 0)
118     b = torch.cat(ub, 0)
119     c = torch.cat(uc, 0)
120     perm = torch.randperm(a.size(0))
121     a = a[perm].contiguous()
122
123     if args.independent:
124         perm = torch.randperm(a.size(0))
125     b = b[perm].contiguous()
126
127     return a, b, c
128
129 ######################################################################
130
131 # Returns a triplet a, b, c where a are the standard MNIST images, c
132 # the classes, and b is a Nx2 tensor, eith for every n:
133 #
134 #   b[n, 0] ~ Uniform(0, 10)
135 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
136
137 def create_image_values_pairs(train = False):
138     ua, ub = [], []
139
140     if train:
141         input, target = train_input, train_target
142     else:
143         input, target = test_input, test_target
144
145     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
146     m[used_MNIST_classes] = 1
147     m = m[target]
148     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
149
150     input = input[used_indices].contiguous()
151     target = target[used_indices].contiguous()
152
153     a = input
154     c = target
155
156     b = a.new(a.size(0), 2)
157     b[:, 0].uniform_(0.0, 10.0)
158     b[:, 1].uniform_(0.0, 0.5)
159
160     if args.independent:
161         b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
162     else:
163         b[:, 1] += b[:, 0] + target.float()
164
165     return a, b, c
166
167 ######################################################################
168
169 def create_sequences_pairs(train = False):
170     nb, length = 10000, 1024
171     noise_level = 2e-2
172
173     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
174     if args.independent:
175         hb = torch.randint(args.nb_classes, (nb, ), device = device)
176     else:
177         hb = ha
178
179     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
180     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
181     a = a - pos.view(nb, 1)
182     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
183     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
184     noise = a.new(a.size()).normal_(0, noise_level)
185     a = a + noise
186
187     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
188     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
189     b1 = b1 - pos.view(nb, 1)
190     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
191     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
192     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
193     b2 = b2 - pos.view(nb, 1)
194     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
195
196     b = b1 + b2
197     noise = b.new(b.size()).normal_(0, noise_level)
198     b = b + noise
199
200     # a = (a - a.mean()) / a.std()
201     # b = (b - b.mean()) / b.std()
202
203     return a, b, ha
204
205 ######################################################################
206
207 class NetForImagePair(nn.Module):
208     def __init__(self):
209         super(NetForImagePair, self).__init__()
210         self.features_a = nn.Sequential(
211             nn.Conv2d(1, 16, kernel_size = 5),
212             nn.MaxPool2d(3), nn.ReLU(),
213             nn.Conv2d(16, 32, kernel_size = 5),
214             nn.MaxPool2d(2), nn.ReLU(),
215         )
216
217         self.features_b = nn.Sequential(
218             nn.Conv2d(1, 16, kernel_size = 5),
219             nn.MaxPool2d(3), nn.ReLU(),
220             nn.Conv2d(16, 32, kernel_size = 5),
221             nn.MaxPool2d(2), nn.ReLU(),
222         )
223
224         self.fully_connected = nn.Sequential(
225             nn.Linear(256, 200),
226             nn.ReLU(),
227             nn.Linear(200, 1)
228         )
229
230     def forward(self, a, b):
231         a = self.features_a(a).view(a.size(0), -1)
232         b = self.features_b(b).view(b.size(0), -1)
233         x = torch.cat((a, b), 1)
234         return self.fully_connected(x)
235
236 ######################################################################
237
238 class NetForImageValuesPair(nn.Module):
239     def __init__(self):
240         super(NetForImageValuesPair, self).__init__()
241         self.features_a = nn.Sequential(
242             nn.Conv2d(1, 16, kernel_size = 5),
243             nn.MaxPool2d(3), nn.ReLU(),
244             nn.Conv2d(16, 32, kernel_size = 5),
245             nn.MaxPool2d(2), nn.ReLU(),
246         )
247
248         self.features_b = nn.Sequential(
249             nn.Linear(2, 32), nn.ReLU(),
250             nn.Linear(32, 32), nn.ReLU(),
251             nn.Linear(32, 128), nn.ReLU(),
252         )
253
254         self.fully_connected = nn.Sequential(
255             nn.Linear(256, 200),
256             nn.ReLU(),
257             nn.Linear(200, 1)
258         )
259
260     def forward(self, a, b):
261         a = self.features_a(a).view(a.size(0), -1)
262         b = self.features_b(b).view(b.size(0), -1)
263         x = torch.cat((a, b), 1)
264         return self.fully_connected(x)
265
266 ######################################################################
267
268 class NetForSequencePair(nn.Module):
269
270     def feature_model(self):
271         kernel_size = 11
272         pooling_size = 4
273         return  nn.Sequential(
274             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
275             nn.AvgPool1d(pooling_size),
276             nn.LeakyReLU(),
277             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
278             nn.AvgPool1d(pooling_size),
279             nn.LeakyReLU(),
280             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
281             nn.AvgPool1d(pooling_size),
282             nn.LeakyReLU(),
283             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
284             nn.AvgPool1d(pooling_size),
285             nn.LeakyReLU(),
286         )
287
288     def __init__(self):
289         super(NetForSequencePair, self).__init__()
290
291         self.nc = 32
292         self.nh = 256
293
294         self.features_a = self.feature_model()
295         self.features_b = self.feature_model()
296
297         self.fully_connected = nn.Sequential(
298             nn.Linear(2 * self.nc, self.nh),
299             nn.ReLU(),
300             nn.Linear(self.nh, 1)
301         )
302
303     def forward(self, a, b):
304         a = a.view(a.size(0), 1, a.size(1))
305         a = self.features_a(a)
306         a = F.avg_pool1d(a, a.size(2))
307
308         b = b.view(b.size(0), 1, b.size(1))
309         b = self.features_b(b)
310         b = F.avg_pool1d(b, b.size(2))
311
312         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
313         return self.fully_connected(x)
314
315 ######################################################################
316
317 if args.data == 'image_pair':
318     create_pairs = create_image_pairs
319     model = NetForImagePair()
320 elif args.data == 'image_values_pair':
321     create_pairs = create_image_values_pairs
322     model = NetForImageValuesPair()
323 elif args.data == 'sequence_pair':
324     create_pairs = create_sequences_pairs
325     model = NetForSequencePair()
326     ######################################################################
327     a, b, c = create_pairs()
328     for k in range(10):
329         file = open(f'train_{k:02d}.dat', 'w')
330         for i in range(a.size(1)):
331             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
332         file.close()
333     # exit(0)
334     ######################################################################
335 else:
336     raise Exception('Unknown data ' + args.data)
337
338 ######################################################################
339
340 print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
341
342 model.to(device)
343
344 for e in range(args.nb_epochs):
345
346     input_a, input_b, classes = create_pairs(train = True)
347
348     input_br = input_b[torch.randperm(input_b.size(0))]
349
350     acc_mi = 0.0
351
352     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
353
354     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
355                                           input_b.split(args.batch_size),
356                                           input_br.split(args.batch_size)):
357         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
358         acc_mi += mi.item()
359         loss = - mi
360         optimizer.zero_grad()
361         loss.backward()
362         optimizer.step()
363
364     acc_mi /= (input_a.size(0) // args.batch_size)
365
366     print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
367
368     sys.stdout.flush()
369
370 ######################################################################
371
372 input_a, input_b, classes = create_pairs(train = False)
373
374 input_br = input_b[torch.randperm(input_b.size(0))]
375
376 acc_mi = 0.0
377
378 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
379                                       input_b.split(args.batch_size),
380                                       input_br.split(args.batch_size)):
381     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
382     acc_mi += mi.item()
383
384 acc_mi /= (input_a.size(0) // args.batch_size)
385
386 print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
387
388 ######################################################################