Update.
[pytorch.git] / miniflow.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
10 import numpy as np
11
12 import math
13 from math import pi
14
15 import torch, torchvision
16
17 from torch import nn, autograd
18 from torch.nn import functional as F
19
20 ######################################################################
21
22
23 def phi(x):
24     p, std = 0.3, 0.2
25     mu = (1 - p) * torch.exp(
26         LogProba((x - 0.5) / std, math.log(1 / std))
27     ) + p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
28     return mu
29
30
31 def sample_phi(nb):
32     p, std = 0.3, 0.2
33     result = torch.empty(nb).normal_(0, std)
34     result = result + torch.sign(torch.rand(result.size()) - p) / 2
35     return result
36
37
38 ######################################################################
39
40
41 def LogProba(x, ldj):
42     log_p = ldj - 0.5 * (x**2 + math.log(2 * pi))
43     return log_p
44
45
46 ######################################################################
47
48
49 class PiecewiseLinear(nn.Module):
50     def __init__(self, nb, xmin, xmax):
51         super().__init__()
52         self.xmin = xmin
53         self.xmax = xmax
54         self.nb = nb
55         self.alpha = nn.Parameter(torch.tensor([xmin], dtype=torch.float))
56         mu = math.log((xmax - xmin) / nb)
57         self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
58
59     def forward(self, x):
60         y = self.alpha + self.xi.exp().cumsum(0)
61         u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
62         n = u.long().clamp(0, self.nb - 1)
63         a = (u - n).clamp(0, 1)
64         x = (1 - a) * y[n] + a * y[n + 1]
65         return x
66
67     def invert(self, y):
68         ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
69         yy = y.view(-1, 1)
70         k = torch.arange(self.nb).view(1, -1)
71         assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
72         yk = ys[:, :-1]
73         ykp1 = ys[:, 1:]
74         x = self.xmin + (self.xmax - self.xmin) / self.nb * (
75             (yy >= yk) * (yy < ykp1).long() * (k + (yy - yk) / (ykp1 - yk))
76         ).sum(1)
77         return x
78
79
80 ######################################################################
81 # Training
82
83 nb_samples = 25000
84 nb_epochs = 250
85 batch_size = 100
86
87 model = PiecewiseLinear(nb=1001, xmin=-4, xmax=4)
88
89 train_input = sample_phi(nb_samples)
90
91 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
92 criterion = nn.MSELoss()
93
94 for k in range(nb_epochs):
95     acc_loss = 0
96
97     for input in train_input.split(batch_size):
98         input.requires_grad_()
99         output = model(input)
100
101         (derivatives,) = autograd.grad(
102             output.sum(), input, retain_graph=True, create_graph=True
103         )
104
105         loss = (0.5 * (output**2 + math.log(2 * pi)) - derivatives.log()).mean()
106
107         optimizer.zero_grad()
108         loss.backward()
109         optimizer.step()
110
111         acc_loss += loss.item()
112     if k % 10 == 0:
113         print(k, loss.item())
114
115 ######################################################################
116
117 input = torch.linspace(-3, 3, 175)
118
119 mu = phi(input)
120 mu_N = torch.exp(LogProba(input, 0))
121
122 input.requires_grad_()
123 output = model(input)
124
125 grad = autograd.grad(output.sum(), input)[0]
126 mu_hat = LogProba(output, grad.log()).detach().exp()
127
128 ######################################################################
129 # FIGURES
130
131 input = input.detach().numpy()
132 output = output.detach().numpy()
133 mu = mu.numpy()
134 mu_hat = mu_hat.numpy()
135
136 ######################################################################
137
138 fig = plt.figure()
139 ax = fig.add_subplot(1, 1, 1)
140 # ax.set_xlim(-5, 5)
141 # ax.set_ylim(-0.25, 1.25)
142 # ax.axis('off')
143
144 ax.plot(input, output, "-", color="tab:red")
145
146 filename = "miniflow_mapping.pdf"
147 print(f"Saving {filename}")
148 fig.savefig(filename, bbox_inches="tight")
149
150 # plt.show()
151
152 ######################################################################
153
154 green_dist = "#bfdfbf"
155
156 fig = plt.figure()
157 ax = fig.add_subplot(1, 1, 1)
158 # ax.set_xlim(-4.5, 4.5)
159 # ax.set_ylim(-0.1, 1.1)
160 lines = list(
161     ([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output)
162 )
163 lc = mc.LineCollection(lines, color="tab:red", linewidth=0.1)
164 ax.add_collection(lc)
165 ax.axis("off")
166
167 ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color=green_dist)
168 ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color=green_dist)
169
170 filename = "miniflow_flow.pdf"
171 print(f"Saving {filename}")
172 fig.savefig(filename, bbox_inches="tight")
173
174 # plt.show()
175
176 ######################################################################
177
178 fig = plt.figure()
179 ax = fig.add_subplot(1, 1, 1)
180 ax.axis("off")
181
182 ax.fill_between(input, 0, mu, color=green_dist)
183 # ax.plot(input, mu, '-', color = 'tab:blue')
184 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
185 ax.plot(input, mu_hat, "-", color="tab:red")
186
187 filename = "miniflow_dist.pdf"
188 print(f"Saving {filename}")
189 fig.savefig(filename, bbox_inches="tight")
190
191 # plt.show()
192
193 ######################################################################
194
195 fig = plt.figure()
196 ax = fig.add_subplot(1, 1, 1)
197 ax.axis("off")
198
199 # ax.plot(input, mu, '-', color = 'tab:blue')
200 ax.fill_between(input, 0, mu, color=green_dist)
201 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
202
203 filename = "miniflow_target_dist.pdf"
204 print(f"Saving {filename}")
205 fig.savefig(filename, bbox_inches="tight")
206
207 # plt.show()
208
209 ######################################################################
210
211 # z = torch.empty(200).normal_()
212 # z = z[(z > -3) * (z < 3)]
213
214 # x = model.invert(z)
215
216 # fig = plt.figure()
217 # ax = fig.add_subplot(1, 1, 1)
218 # ax.set_xlim(-4.5, 4.5)
219 # ax.set_ylim(-0.1, 1.1)
220 # lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
221 # lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
222 # ax.add_collection(lc)
223 # # ax.axis('off')
224
225 # # ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
226 # # ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
227
228 # filename = 'miniflow_synth.pdf'
229 # print(f'Saving {filename}')
230 # fig.savefig(filename, bbox_inches='tight')
231
232 # # plt.show()