OCD update.
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 import argparse, math, sys
4
5 import torch, torchvision
6
7 from torch import nn
8
9 ######################################################################
10
11 if torch.cuda.is_available():
12     device = torch.device('cuda')
13 else:
14     device = torch.device('cpu')
15
16 ######################################################################
17
18 parser = argparse.ArgumentParser(
19     description = 'An implementation of Mutual Information estimator with a deep model',
20     formatter_class = argparse.ArgumentDefaultsHelpFormatter
21 )
22
23 parser.add_argument('--data',
24                     type = str, default = 'image_pair',
25                     help = 'What data')
26
27 parser.add_argument('--seed',
28                     type = int, default = 0,
29                     help = 'Random seed (default 0, < 0 is no seeding)')
30
31 parser.add_argument('--mnist_classes',
32                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
33                     help = 'What MNIST classes to use')
34
35 ######################################################################
36
37 def entropy(target):
38     probas = []
39     for k in range(target.max() + 1):
40         n = (target == k).sum().item()
41         if n > 0: probas.append(n)
42     probas = torch.tensor(probas).float()
43     probas /= probas.sum()
44     return - (probas * probas.log()).sum().item()
45
46 ######################################################################
47
48 args = parser.parse_args()
49
50 if args.seed >= 0:
51     torch.manual_seed(args.seed)
52
53 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
54
55 ######################################################################
56
57 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
58 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
59 train_target = train_set.train_labels.to(device)
60
61 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
62 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
63 test_target = test_set.test_labels.to(device)
64
65 mu, std = train_input.mean(), train_input.std()
66 train_input.sub_(mu).div_(std)
67 test_input.sub_(mu).div_(std)
68
69 ######################################################################
70
71 # Returns a triplet of tensors (a, b, c), where a and b contain each
72 # half of the samples, with a[i] and b[i] of same class for any i, and
73 # c is a 1d long tensor real classes
74
75 def create_image_pairs(train = False):
76     ua, ub = [], []
77
78     if train:
79         input, target = train_input, train_target
80     else:
81         input, target = test_input, test_target
82
83     for i in used_MNIST_classes:
84         used_indices = torch.arange(input.size(0), device = target.device)\
85                             .masked_select(target == i.item())
86         x = input[used_indices]
87         x = x[torch.randperm(x.size(0))]
88         hs = x.size(0)//2
89         ua.append(x.narrow(0, 0, hs))
90         ub.append(x.narrow(0, hs, hs))
91         uc.append(target[used_indices])
92
93     a = torch.cat(ua, 0)
94     b = torch.cat(ub, 0)
95     c = torch.cat(uc, 0)
96     perm = torch.randperm(a.size(0))
97     a = a[perm].contiguous()
98     b = b[perm].contiguous()
99
100     return a, b, c
101
102 ######################################################################
103
104 # Returns a triplet a, b, c where a are the standard MNIST images, c
105 # the classes, and b is a Nx2 tensor, eith for every n:
106 #
107 #   b[n, 0] ~ Uniform(0, 10)
108 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
109
110 def create_image_values_pairs(train = False):
111     ua, ub = [], []
112
113     if train:
114         input, target = train_input, train_target
115     else:
116         input, target = test_input, test_target
117
118     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
119     m[used_MNIST_classes] = 1
120     m = m[target]
121     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
122
123     input = input[used_indices].contiguous()
124     target = target[used_indices].contiguous()
125
126     a = input
127     c = target
128
129     b = a.new(a.size(0), 2)
130     b[:, 0].uniform_(10)
131     b[:, 1].uniform_(0.5)
132     b[:, 1] += b[:, 0] + target.float()
133
134     return a, b, c
135
136 ######################################################################
137
138 class NetForImagePair(nn.Module):
139     def __init__(self):
140         super(NetForImagePair, self).__init__()
141         self.features_a = nn.Sequential(
142             nn.Conv2d(1, 16, kernel_size = 5),
143             nn.MaxPool2d(3), nn.ReLU(),
144             nn.Conv2d(16, 32, kernel_size = 5),
145             nn.MaxPool2d(2), nn.ReLU(),
146         )
147
148         self.features_b = nn.Sequential(
149             nn.Conv2d(1, 16, kernel_size = 5),
150             nn.MaxPool2d(3), nn.ReLU(),
151             nn.Conv2d(16, 32, kernel_size = 5),
152             nn.MaxPool2d(2), nn.ReLU(),
153         )
154
155         self.fully_connected = nn.Sequential(
156             nn.Linear(256, 200),
157             nn.ReLU(),
158             nn.Linear(200, 1)
159         )
160
161     def forward(self, a, b):
162         a = self.features_a(a).view(a.size(0), -1)
163         b = self.features_b(b).view(b.size(0), -1)
164         x = torch.cat((a, b), 1)
165         return self.fully_connected(x)
166
167 ######################################################################
168
169 class NetForImageValuesPair(nn.Module):
170     def __init__(self):
171         super(NetForImageValuesPair, self).__init__()
172         self.features_a = nn.Sequential(
173             nn.Conv2d(1, 16, kernel_size = 5),
174             nn.MaxPool2d(3), nn.ReLU(),
175             nn.Conv2d(16, 32, kernel_size = 5),
176             nn.MaxPool2d(2), nn.ReLU(),
177         )
178
179         self.features_b = nn.Sequential(
180             nn.Linear(2, 32), nn.ReLU(),
181             nn.Linear(32, 32), nn.ReLU(),
182             nn.Linear(32, 128), nn.ReLU(),
183         )
184
185         self.fully_connected = nn.Sequential(
186             nn.Linear(256, 200),
187             nn.ReLU(),
188             nn.Linear(200, 1)
189         )
190
191     def forward(self, a, b):
192         a = self.features_a(a).view(a.size(0), -1)
193         b = self.features_b(b).view(b.size(0), -1)
194         x = torch.cat((a, b), 1)
195         return self.fully_connected(x)
196
197 ######################################################################
198
199 if args.data == 'image_pair':
200     create_pairs = create_image_pairs
201     model = NetForImagePair()
202 elif args.data == 'image_values_pair':
203     create_pairs = create_image_values_pairs
204     model = NetForImageValuesPair()
205 else:
206     raise Exception('Unknown data ' + args.data)
207
208 ######################################################################
209
210 nb_epochs, batch_size = 50, 100
211
212 print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
213
214 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
215
216 model.to(device)
217
218 for e in range(nb_epochs):
219
220     input_a, input_b, classes = create_pairs(train = True)
221
222     input_br = input_b[torch.randperm(input_b.size(0))]
223
224     acc_mi = 0.0
225
226     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
227                                           input_b.split(batch_size),
228                                           input_br.split(batch_size)):
229         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
230         loss = - mi
231         acc_mi += mi.item()
232         optimizer.zero_grad()
233         loss.backward()
234         optimizer.step()
235
236     acc_mi /= (input_a.size(0) // batch_size)
237
238     print('%d %.04f %.04f' % (e, acc_mi / math.log(2), entropy(classes) / math.log(2)))
239
240     sys.stdout.flush()
241
242 ######################################################################
243
244 input_a, input_b, classes = create_pairs(train = False)
245
246 for e in range(nb_epochs):
247     input_br = input_b[torch.randperm(input_b.size(0))]
248
249     acc_mi = 0.0
250
251     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
252                                           input_b.split(batch_size),
253                                           input_br.split(batch_size)):
254         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
255         acc_mi += mi.item()
256
257     acc_mi /= (input_a.size(0) // batch_size)
258
259 print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
260
261 ######################################################################