Update.
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 import argparse, math, sys
4 from copy import deepcopy
5
6 import torch, torchvision
7
8 from torch import nn
9 import torch.nn.functional as F
10
11 ######################################################################
12
13 if torch.cuda.is_available():
14     device = torch.device('cuda')
15     torch.backends.cudnn.benchmark = True
16 else:
17     device = torch.device('cpu')
18
19 ######################################################################
20
21 parser = argparse.ArgumentParser(
22     description = 'An implementation of Mutual Information estimator with a deep model',
23     formatter_class = argparse.ArgumentDefaultsHelpFormatter
24 )
25
26 parser.add_argument('--data',
27                     type = str, default = 'image_pair',
28                     help = 'What data')
29
30 parser.add_argument('--seed',
31                     type = int, default = 0,
32                     help = 'Random seed (default 0, < 0 is no seeding)')
33
34 parser.add_argument('--mnist_classes',
35                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
36                     help = 'What MNIST classes to use')
37
38 ######################################################################
39
40 def entropy(target):
41     probas = []
42     for k in range(target.max() + 1):
43         n = (target == k).sum().item()
44         if n > 0: probas.append(n)
45     probas = torch.tensor(probas).float()
46     probas /= probas.sum()
47     return - (probas * probas.log()).sum().item()
48
49 ######################################################################
50
51 args = parser.parse_args()
52
53 if args.seed >= 0:
54     torch.manual_seed(args.seed)
55
56 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
57
58 ######################################################################
59
60 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
61 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
62 train_target = train_set.train_labels.to(device)
63
64 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
65 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
66 test_target = test_set.test_labels.to(device)
67
68 mu, std = train_input.mean(), train_input.std()
69 train_input.sub_(mu).div_(std)
70 test_input.sub_(mu).div_(std)
71
72 ######################################################################
73
74 # Returns a triplet of tensors (a, b, c), where a and b contain each
75 # half of the samples, with a[i] and b[i] of same class for any i, and
76 # c is a 1d long tensor real classes
77
78 def create_image_pairs(train = False):
79     ua, ub, uc = [], [], []
80
81     if train:
82         input, target = train_input, train_target
83     else:
84         input, target = test_input, test_target
85
86     for i in used_MNIST_classes:
87         used_indices = torch.arange(input.size(0), device = target.device)\
88                             .masked_select(target == i.item())
89         x = input[used_indices]
90         x = x[torch.randperm(x.size(0))]
91         hs = x.size(0)//2
92         ua.append(x.narrow(0, 0, hs))
93         ub.append(x.narrow(0, hs, hs))
94         uc.append(target[used_indices])
95
96     a = torch.cat(ua, 0)
97     b = torch.cat(ub, 0)
98     c = torch.cat(uc, 0)
99     perm = torch.randperm(a.size(0))
100     a = a[perm].contiguous()
101     b = b[perm].contiguous()
102
103     return a, b, c
104
105 ######################################################################
106
107 # Returns a triplet a, b, c where a are the standard MNIST images, c
108 # the classes, and b is a Nx2 tensor, eith for every n:
109 #
110 #   b[n, 0] ~ Uniform(0, 10)
111 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
112
113 def create_image_values_pairs(train = False):
114     ua, ub = [], []
115
116     if train:
117         input, target = train_input, train_target
118     else:
119         input, target = test_input, test_target
120
121     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
122     m[used_MNIST_classes] = 1
123     m = m[target]
124     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
125
126     input = input[used_indices].contiguous()
127     target = target[used_indices].contiguous()
128
129     a = input
130     c = target
131
132     b = a.new(a.size(0), 2)
133     b[:, 0].uniform_(10)
134     b[:, 1].uniform_(0.5)
135     b[:, 1] += b[:, 0] + target.float()
136
137     return a, b, c
138
139 ######################################################################
140
141 def create_sequences_pairs(train = False):
142     nb, length = 10000, 1024
143     noise_level = 1e-2
144
145     nb_classes = 4
146     ha = torch.randint(nb_classes, (nb, ), device = device) + 1
147     # hb = torch.randint(nb_classes, (nb, ), device = device)
148     hb = ha
149
150     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
151     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
152     a = a - pos.view(nb, 1)
153     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
154     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + nb_classes)
155     noise = a.new(a.size()).normal_(0, noise_level)
156     a = a + noise
157
158     pos = torch.empty(nb, device = device).uniform_(0.5)
159     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
160     b1 = b1 - pos.view(nb, 1)
161     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1)
162     pos = pos + hb.float() / (nb_classes + 1) * 0.5
163     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
164     b2 = b2 - pos.view(nb, 1)
165     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1)
166
167     b = b1 + b2
168     noise = b.new(b.size()).normal_(0, noise_level)
169     b = b + noise
170
171     ######################################################################
172     # for k in range(10):
173         # file = open(f'/tmp/dat{k:02d}', 'w')
174         # for i in range(a.size(1)):
175             # file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
176         # file.close()
177     # exit(0)
178     ######################################################################
179
180     a = (a - a.mean()) / a.std()
181     b = (b - b.mean()) / b.std()
182
183     return a, b, ha
184
185 ######################################################################
186
187 class NetForImagePair(nn.Module):
188     def __init__(self):
189         super(NetForImagePair, self).__init__()
190         self.features_a = nn.Sequential(
191             nn.Conv2d(1, 16, kernel_size = 5),
192             nn.MaxPool2d(3), nn.ReLU(),
193             nn.Conv2d(16, 32, kernel_size = 5),
194             nn.MaxPool2d(2), nn.ReLU(),
195         )
196
197         self.features_b = nn.Sequential(
198             nn.Conv2d(1, 16, kernel_size = 5),
199             nn.MaxPool2d(3), nn.ReLU(),
200             nn.Conv2d(16, 32, kernel_size = 5),
201             nn.MaxPool2d(2), nn.ReLU(),
202         )
203
204         self.fully_connected = nn.Sequential(
205             nn.Linear(256, 200),
206             nn.ReLU(),
207             nn.Linear(200, 1)
208         )
209
210     def forward(self, a, b):
211         a = self.features_a(a).view(a.size(0), -1)
212         b = self.features_b(b).view(b.size(0), -1)
213         x = torch.cat((a, b), 1)
214         return self.fully_connected(x)
215
216 ######################################################################
217
218 class NetForImageValuesPair(nn.Module):
219     def __init__(self):
220         super(NetForImageValuesPair, self).__init__()
221         self.features_a = nn.Sequential(
222             nn.Conv2d(1, 16, kernel_size = 5),
223             nn.MaxPool2d(3), nn.ReLU(),
224             nn.Conv2d(16, 32, kernel_size = 5),
225             nn.MaxPool2d(2), nn.ReLU(),
226         )
227
228         self.features_b = nn.Sequential(
229             nn.Linear(2, 32), nn.ReLU(),
230             nn.Linear(32, 32), nn.ReLU(),
231             nn.Linear(32, 128), nn.ReLU(),
232         )
233
234         self.fully_connected = nn.Sequential(
235             nn.Linear(256, 200),
236             nn.ReLU(),
237             nn.Linear(200, 1)
238         )
239
240     def forward(self, a, b):
241         a = self.features_a(a).view(a.size(0), -1)
242         b = self.features_b(b).view(b.size(0), -1)
243         x = torch.cat((a, b), 1)
244         return self.fully_connected(x)
245
246 ######################################################################
247
248 class NetForSequencePair(nn.Module):
249
250     def feature_model(self):
251         return  nn.Sequential(
252             nn.Conv1d(1, self.nc, kernel_size = 5),
253             nn.MaxPool1d(2), nn.ReLU(),
254             nn.Conv1d(self.nc, self.nc, kernel_size = 5),
255             nn.MaxPool1d(2), nn.ReLU(),
256             nn.Conv1d(self.nc, self.nc, kernel_size = 5),
257             nn.MaxPool1d(2), nn.ReLU(),
258             nn.Conv1d(self.nc, self.nc, kernel_size = 5),
259             nn.MaxPool1d(2), nn.ReLU(),
260             nn.Conv1d(self.nc, self.nc, kernel_size = 5),
261             nn.MaxPool1d(2), nn.ReLU(),
262         )
263
264     def __init__(self):
265         super(NetForSequencePair, self).__init__()
266
267         self.nc = 32
268         self.nh = 256
269
270         self.features_a = self.feature_model()
271         self.features_b = self.feature_model()
272
273         self.fully_connected = nn.Sequential(
274             nn.Linear(2 * self.nc, self.nh),
275             nn.ReLU(),
276             nn.Linear(self.nh, 1)
277         )
278
279     def forward(self, a, b):
280         a = a.view(a.size(0), 1, a.size(1))
281         a = self.features_a(a)
282         a = F.avg_pool1d(a, a.size(2))
283
284         b = b.view(b.size(0), 1, b.size(1))
285         b = self.features_b(b)
286         b = F.avg_pool1d(b, b.size(2))
287
288         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
289         return self.fully_connected(x)
290
291 ######################################################################
292
293 if args.data == 'image_pair':
294     create_pairs = create_image_pairs
295     model = NetForImagePair()
296 elif args.data == 'image_values_pair':
297     create_pairs = create_image_values_pairs
298     model = NetForImageValuesPair()
299 elif args.data == 'sequence_pair':
300     create_pairs = create_sequences_pairs
301     model = NetForSequencePair()
302 else:
303     raise Exception('Unknown data ' + args.data)
304
305 ######################################################################
306
307 nb_epochs, batch_size = 50, 100
308
309 print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
310
311 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
312
313 model.to(device)
314
315 for e in range(nb_epochs):
316
317     input_a, input_b, classes = create_pairs(train = True)
318
319     input_br = input_b[torch.randperm(input_b.size(0))]
320
321     acc_mi = 0.0
322
323     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
324                                           input_b.split(batch_size),
325                                           input_br.split(batch_size)):
326         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
327         loss = - mi
328         acc_mi += mi.item()
329         optimizer.zero_grad()
330         loss.backward()
331         optimizer.step()
332
333     acc_mi /= (input_a.size(0) // batch_size)
334
335     print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
336
337     sys.stdout.flush()
338
339 ######################################################################
340
341 input_a, input_b, classes = create_pairs(train = False)
342
343 for e in range(nb_epochs):
344     input_br = input_b[torch.randperm(input_b.size(0))]
345
346     acc_mi = 0.0
347
348     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
349                                           input_b.split(batch_size),
350                                           input_br.split(batch_size)):
351         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
352         acc_mi += mi.item()
353
354     acc_mi /= (input_a.size(0) // batch_size)
355
356 print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
357
358 ######################################################################