0c485b20ad4689f6652bbf4da6078f789f6e8950
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 import argparse, math, sys
4 from copy import deepcopy
5
6 import torch, torchvision
7
8 from torch import nn
9 import torch.nn.functional as F
10
11 ######################################################################
12
13 if torch.cuda.is_available():
14     torch.backends.cudnn.benchmark = True
15     device = torch.device('cuda')
16 else:
17     device = torch.device('cpu')
18
19 ######################################################################
20
21 parser = argparse.ArgumentParser(
22     description = 'An implementation of Mutual Information estimator with a deep model',
23     formatter_class = argparse.ArgumentDefaultsHelpFormatter
24 )
25
26 parser.add_argument('--data',
27                     type = str, default = 'image_pair',
28                     help = 'What data')
29
30 parser.add_argument('--seed',
31                     type = int, default = 0,
32                     help = 'Random seed (default 0, < 0 is no seeding)')
33
34 parser.add_argument('--mnist_classes',
35                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
36                     help = 'What MNIST classes to use')
37
38 parser.add_argument('--nb_classes',
39                     type = int, default = 2,
40                     help = 'How many classes for sequences')
41
42 parser.add_argument('--nb_epochs',
43                     type = int, default = 50,
44                     help = 'How many epochs')
45
46 parser.add_argument('--batch_size',
47                     type = int, default = 100,
48                     help = 'Batch size')
49
50 ######################################################################
51
52 def entropy(target):
53     probas = []
54     for k in range(target.max() + 1):
55         n = (target == k).sum().item()
56         if n > 0: probas.append(n)
57     probas = torch.tensor(probas).float()
58     probas /= probas.sum()
59     return - (probas * probas.log()).sum().item()
60
61 def robust_log_mean_exp(x):
62     # a = x.max()
63     # return (x-a).exp().mean().log() + a
64     # a = x.max()
65     return x.exp().mean().log()
66
67 ######################################################################
68
69 args = parser.parse_args()
70
71 if args.seed >= 0:
72     torch.manual_seed(args.seed)
73
74 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
75
76 ######################################################################
77
78 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
79 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
80 train_target = train_set.train_labels.to(device)
81
82 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
83 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
84 test_target = test_set.test_labels.to(device)
85
86 mu, std = train_input.mean(), train_input.std()
87 train_input.sub_(mu).div_(std)
88 test_input.sub_(mu).div_(std)
89
90 ######################################################################
91
92 # Returns a triplet of tensors (a, b, c), where a and b contain each
93 # half of the samples, with a[i] and b[i] of same class for any i, and
94 # c is a 1d long tensor real classes
95
96 def create_image_pairs(train = False):
97     ua, ub, uc = [], [], []
98
99     if train:
100         input, target = train_input, train_target
101     else:
102         input, target = test_input, test_target
103
104     for i in used_MNIST_classes:
105         used_indices = torch.arange(input.size(0), device = target.device)\
106                             .masked_select(target == i.item())
107         x = input[used_indices]
108         x = x[torch.randperm(x.size(0))]
109         hs = x.size(0)//2
110         ua.append(x.narrow(0, 0, hs))
111         ub.append(x.narrow(0, hs, hs))
112         uc.append(target[used_indices])
113
114     a = torch.cat(ua, 0)
115     b = torch.cat(ub, 0)
116     c = torch.cat(uc, 0)
117     perm = torch.randperm(a.size(0))
118     a = a[perm].contiguous()
119     b = b[perm].contiguous()
120
121     return a, b, c
122
123 ######################################################################
124
125 # Returns a triplet a, b, c where a are the standard MNIST images, c
126 # the classes, and b is a Nx2 tensor, eith for every n:
127 #
128 #   b[n, 0] ~ Uniform(0, 10)
129 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
130
131 def create_image_values_pairs(train = False):
132     ua, ub = [], []
133
134     if train:
135         input, target = train_input, train_target
136     else:
137         input, target = test_input, test_target
138
139     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
140     m[used_MNIST_classes] = 1
141     m = m[target]
142     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
143
144     input = input[used_indices].contiguous()
145     target = target[used_indices].contiguous()
146
147     a = input
148     c = target
149
150     b = a.new(a.size(0), 2)
151     b[:, 0].uniform_(0.0, 10.0)
152     b[:, 1].uniform_(0.0, 0.5)
153     b[:, 1] += b[:, 0] + target.float()
154
155     return a, b, c
156
157 ######################################################################
158
159 def create_sequences_pairs(train = False):
160     nb, length = 10000, 1024
161     noise_level = 2e-2
162
163     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
164     # hb = torch.randint(args.nb_classes, (nb, ), device = device)
165     hb = ha
166
167     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
168     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
169     a = a - pos.view(nb, 1)
170     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
171     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
172     noise = a.new(a.size()).normal_(0, noise_level)
173     a = a + noise
174
175     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
176     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
177     b1 = b1 - pos.view(nb, 1)
178     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
179     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
180     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
181     b2 = b2 - pos.view(nb, 1)
182     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
183
184     b = b1 + b2
185     noise = b.new(b.size()).normal_(0, noise_level)
186     b = b + noise
187
188     # a = (a - a.mean()) / a.std()
189     # b = (b - b.mean()) / b.std()
190
191     return a, b, ha
192
193 ######################################################################
194
195 class NetForImagePair(nn.Module):
196     def __init__(self):
197         super(NetForImagePair, self).__init__()
198         self.features_a = nn.Sequential(
199             nn.Conv2d(1, 16, kernel_size = 5),
200             nn.MaxPool2d(3), nn.ReLU(),
201             nn.Conv2d(16, 32, kernel_size = 5),
202             nn.MaxPool2d(2), nn.ReLU(),
203         )
204
205         self.features_b = nn.Sequential(
206             nn.Conv2d(1, 16, kernel_size = 5),
207             nn.MaxPool2d(3), nn.ReLU(),
208             nn.Conv2d(16, 32, kernel_size = 5),
209             nn.MaxPool2d(2), nn.ReLU(),
210         )
211
212         self.fully_connected = nn.Sequential(
213             nn.Linear(256, 200),
214             nn.ReLU(),
215             nn.Linear(200, 1)
216         )
217
218     def forward(self, a, b):
219         a = self.features_a(a).view(a.size(0), -1)
220         b = self.features_b(b).view(b.size(0), -1)
221         x = torch.cat((a, b), 1)
222         return self.fully_connected(x)
223
224 ######################################################################
225
226 class NetForImageValuesPair(nn.Module):
227     def __init__(self):
228         super(NetForImageValuesPair, self).__init__()
229         self.features_a = nn.Sequential(
230             nn.Conv2d(1, 16, kernel_size = 5),
231             nn.MaxPool2d(3), nn.ReLU(),
232             nn.Conv2d(16, 32, kernel_size = 5),
233             nn.MaxPool2d(2), nn.ReLU(),
234         )
235
236         self.features_b = nn.Sequential(
237             nn.Linear(2, 32), nn.ReLU(),
238             nn.Linear(32, 32), nn.ReLU(),
239             nn.Linear(32, 128), nn.ReLU(),
240         )
241
242         self.fully_connected = nn.Sequential(
243             nn.Linear(256, 200),
244             nn.ReLU(),
245             nn.Linear(200, 1)
246         )
247
248     def forward(self, a, b):
249         a = self.features_a(a).view(a.size(0), -1)
250         b = self.features_b(b).view(b.size(0), -1)
251         x = torch.cat((a, b), 1)
252         return self.fully_connected(x)
253
254 ######################################################################
255
256 class NetForSequencePair(nn.Module):
257
258     def feature_model(self):
259         kernel_size = 11
260         pooling_size = 4
261         return  nn.Sequential(
262             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
263             nn.AvgPool1d(pooling_size),
264             nn.LeakyReLU(),
265             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
266             nn.AvgPool1d(pooling_size),
267             nn.LeakyReLU(),
268             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
269             nn.AvgPool1d(pooling_size),
270             nn.LeakyReLU(),
271             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
272             nn.AvgPool1d(pooling_size),
273             nn.LeakyReLU(),
274         )
275
276     def __init__(self):
277         super(NetForSequencePair, self).__init__()
278
279         self.nc = 32
280         self.nh = 256
281
282         self.features_a = self.feature_model()
283         self.features_b = self.feature_model()
284
285         self.fully_connected = nn.Sequential(
286             nn.Linear(2 * self.nc, self.nh),
287             nn.ReLU(),
288             nn.Linear(self.nh, 1)
289         )
290
291     def forward(self, a, b):
292         a = a.view(a.size(0), 1, a.size(1))
293         a = self.features_a(a)
294         a = F.avg_pool1d(a, a.size(2))
295
296         b = b.view(b.size(0), 1, b.size(1))
297         b = self.features_b(b)
298         b = F.avg_pool1d(b, b.size(2))
299
300         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
301         return self.fully_connected(x)
302
303 ######################################################################
304
305 if args.data == 'image_pair':
306     create_pairs = create_image_pairs
307     model = NetForImagePair()
308 elif args.data == 'image_values_pair':
309     create_pairs = create_image_values_pairs
310     model = NetForImageValuesPair()
311 elif args.data == 'sequence_pair':
312     create_pairs = create_sequences_pairs
313     model = NetForSequencePair()
314     ######################################################################
315     a, b, c = create_pairs()
316     for k in range(10):
317         file = open(f'/tmp/train_{k:02d}.dat', 'w')
318         for i in range(a.size(1)):
319             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
320         file.close()
321     # exit(0)
322     ######################################################################
323 else:
324     raise Exception('Unknown data ' + args.data)
325
326 ######################################################################
327
328 print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
329
330 model.to(device)
331
332 for e in range(args.nb_epochs):
333
334     input_a, input_b, classes = create_pairs(train = True)
335
336     input_br = input_b[torch.randperm(input_b.size(0))]
337
338     acc_mi = 0.0
339
340     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
341
342     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
343                                           input_b.split(args.batch_size),
344                                           input_br.split(args.batch_size)):
345         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
346         acc_mi += mi.item()
347         loss = - mi
348         optimizer.zero_grad()
349         loss.backward()
350         optimizer.step()
351
352     acc_mi /= (input_a.size(0) // args.batch_size)
353
354     print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
355
356     sys.stdout.flush()
357
358 ######################################################################
359
360 input_a, input_b, classes = create_pairs(train = False)
361
362 input_br = input_b[torch.randperm(input_b.size(0))]
363
364 acc_mi = 0.0
365
366 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
367                                       input_b.split(args.batch_size),
368                                       input_br.split(args.batch_size)):
369     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
370     acc_mi += mi.item()
371
372 acc_mi /= (input_a.size(0) // args.batch_size)
373
374 print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
375
376 ######################################################################