Update.
[pytorch] / miniflow.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
10 import numpy as np
11
12 import math
13 from math import pi
14
15 import torch, torchvision
16
17 from torch import nn, autograd
18 from torch.nn import functional as F
19
20 ######################################################################
21
22 def phi(x):
23     p, std = 0.3, 0.2
24     mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
25               p  * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
26     return mu
27
28 def sample_phi(nb):
29     p, std = 0.3, 0.2
30     result = torch.empty(nb).normal_(0, std)
31     result = result + torch.sign(torch.rand(result.size()) - p) / 2
32     return result
33
34 ######################################################################
35
36 # START_LOG_PROBA
37 def LogProba(x, ldj):
38     log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
39     return log_p
40 # END_LOG_PROBA
41
42 ######################################################################
43
44 # START_MODEL
45 class PiecewiseLinear(nn.Module):
46     def __init__(self, nb, xmin, xmax):
47         super(PiecewiseLinear, self).__init__()
48         self.xmin = xmin
49         self.xmax = xmax
50         self.nb = nb
51         self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
52         mu = math.log((xmax - xmin) / nb)
53         self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
54
55     def forward(self, x):
56         y = self.alpha + self.xi.exp().cumsum(0)
57         u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
58         n = u.long().clamp(0, self.nb - 1)
59         a = (u - n).clamp(0, 1)
60         x = (1 - a) * y[n] + a * y[n + 1]
61         return x
62 # END_MODEL
63
64     def invert(self, y):
65         ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
66         yy = y.view(-1, 1)
67         k = torch.arange(self.nb).view(1, -1)
68         assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
69         yk = ys[:, :-1]
70         ykp1 = ys[:, 1:]
71         x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
72         return x
73
74 ######################################################################
75 # Training
76
77 nb_samples = 25000
78 nb_epochs = 250
79 batch_size = 100
80
81 model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
82
83 train_input = sample_phi(nb_samples)
84
85 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
86 criterion = nn.MSELoss()
87
88 for k in range(nb_epochs):
89     acc_loss = 0
90
91 # START_OPTIMIZATION
92     for input in train_input.split(batch_size):
93         input.requires_grad_()
94         output = model(input)
95
96         derivatives, = autograd.grad(
97             output.sum(), input,
98             retain_graph = True, create_graph = True
99         )
100
101         loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
102
103         optimizer.zero_grad()
104         loss.backward()
105         optimizer.step()
106 # END_OPTIMIZATION
107
108         acc_loss += loss.item()
109     if k%10 == 0: print(k, loss.item())
110
111 ######################################################################
112
113 input = torch.linspace(-3, 3, 175)
114
115 mu = phi(input)
116 mu_N = torch.exp(LogProba(input, 0))
117
118 input.requires_grad_()
119 output = model(input)
120
121 grad = autograd.grad(output.sum(), input)[0]
122 mu_hat = LogProba(output, grad.log()).detach().exp()
123
124 ######################################################################
125 # FIGURES
126
127 input = input.detach().numpy()
128 output = output.detach().numpy()
129 mu = mu.numpy()
130 mu_hat = mu_hat.numpy()
131
132 ######################################################################
133
134 fig = plt.figure()
135 ax = fig.add_subplot(1, 1, 1)
136 # ax.set_xlim(-5, 5)
137 # ax.set_ylim(-0.25, 1.25)
138 # ax.axis('off')
139
140 ax.plot(input, output, '-', color = 'tab:red')
141
142 filename = 'miniflow_mapping.pdf'
143 print(f'Saving {filename}')
144 fig.savefig(filename, bbox_inches='tight')
145
146 # plt.show()
147
148 ######################################################################
149
150 green_dist = '#bfdfbf'
151
152 fig = plt.figure()
153 ax = fig.add_subplot(1, 1, 1)
154 # ax.set_xlim(-4.5, 4.5)
155 # ax.set_ylim(-0.1, 1.1)
156 lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
157 lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
158 ax.add_collection(lc)
159 ax.axis('off')
160
161 ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
162 ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
163
164 filename = 'miniflow_flow.pdf'
165 print(f'Saving {filename}')
166 fig.savefig(filename, bbox_inches='tight')
167
168 # plt.show()
169
170 ######################################################################
171
172 fig = plt.figure()
173 ax = fig.add_subplot(1, 1, 1)
174 ax.axis('off')
175
176 ax.fill_between(input, 0, mu, color = green_dist)
177 # ax.plot(input, mu, '-', color = 'tab:blue')
178 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
179 ax.plot(input, mu_hat, '-', color = 'tab:red')
180
181 filename = 'miniflow_dist.pdf'
182 print(f'Saving {filename}')
183 fig.savefig(filename, bbox_inches='tight')
184
185 # plt.show()
186
187 ######################################################################
188
189 fig = plt.figure()
190 ax = fig.add_subplot(1, 1, 1)
191 ax.axis('off')
192
193 # ax.plot(input, mu, '-', color = 'tab:blue')
194 ax.fill_between(input, 0, mu, color = green_dist)
195 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
196
197 filename = 'miniflow_target_dist.pdf'
198 print(f'Saving {filename}')
199 fig.savefig(filename, bbox_inches='tight')
200
201 # plt.show()
202
203 ######################################################################
204
205 z = torch.empty(200).normal_()
206 z = z[(z > -3) * (z < 3)]
207
208 x = model.invert(z)
209
210 fig = plt.figure()
211 ax = fig.add_subplot(1, 1, 1)
212 ax.set_xlim(-4.5, 4.5)
213 ax.set_ylim(-0.1, 1.1)
214 lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
215 lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
216 ax.add_collection(lc)
217 # ax.axis('off')
218
219 # ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
220 # ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
221
222 filename = 'miniflow_synth.pdf'
223 print(f'Saving {filename}')
224 fig.savefig(filename, bbox_inches='tight')
225
226 # plt.show()
227