3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
15 import torch, torchvision
17 from torch import nn, autograd
18 from torch.nn import functional as F
20 ######################################################################
# Target density mu(x): a two-component Gaussian mixture with common width std,
# centred at +0.5 (weight 1 - p) and at -0.5 (weight p). p and std are set on
# lines not shown here. Passing log(1 / std) as LogProba's second argument
# presumably adds the log-Jacobian of the standardisation (x - c) / std, so that
# each component is a properly normalised density -- confirm against LogProba.
25 mu = (1 - p) * torch.exp(
26 LogProba((x - 0.5) / std, math.log(1 / std))
27 ) + p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
# Draw nb samples from the two-component mixture: zero-mean Gaussian noise of
# width std, shifted by sign(U - p) / 2 with U ~ Uniform[0, 1). The shift is
# +0.5 with probability 1 - p and -0.5 with probability p. torch.sign returns 0
# only in the measure-zero case U == p.
33 result = torch.empty(nb).normal_(0, std)
34 result = result + torch.sign(torch.rand(result.size()) - p) / 2
38 ######################################################################
# Standard-normal log-density at x, -0.5 * (x^2 + log(2*pi)), plus ldj.
# Judging by the name, ldj is a log-determinant-of-Jacobian correction (confirm).
# pi is presumably imported from math on a line not shown here.
42 log_p = ldj - 0.5 * (x**2 + math.log(2 * pi))
46 ######################################################################
# Monotone piecewise-linear map on [xmin, xmax], split into nb equal input
# segments. The breakpoint values are alpha + cumsum(exp(xi)). Because
# exp(xi) > 0 they are strictly increasing, so the map is invertible on its
# range. Outside [xmin, xmax] the map is clamped to the first or last breakpoint
# value, so its derivative there is 0.
49 class PiecewiseLinear(nn.Module):
50 def __init__(self, nb, xmin, xmax):
# alpha: learnable offset of all breakpoint values, initialised at xmin.
55 self.alpha = nn.Parameter(torch.tensor([xmin], dtype=torch.float))
# Initialise exp(xi) around the input segment width (xmax - xmin) / nb, with
# Gaussian noise of std 1e-4 on xi. The initial map is then close to the
# identity, shifted by roughly one segment width.
56 mu = math.log((xmax - xmin) / nb)
57 self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
# Map x through the piecewise-linear function. y holds the nb + 1 breakpoints.
60 y = self.alpha + self.xi.exp().cumsum(0)
# u: position of x measured in segment units from xmin.
61 u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
# n: segment index, clamped to [0, nb - 1] so that out-of-range inputs use the
# end segments.
62 n = u.long().clamp(0, self.nb - 1)
# a: interpolation weight within segment n. Clamping it to [0, 1] makes the
# output constant (y[0] below xmin, y[nb] above xmax).
63 a = (u - n).clamp(0, 1)
64 x = (1 - a) * y[n] + a * y[n + 1]
# Inverse map. The breakpoints are laid out as a (1, nb + 1) row so they
# broadcast against the queried values.
68 ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
70 k = torch.arange(self.nb).view(1, -1)
# The inverse is only defined between the first and last breakpoint.
# NOTE(review): assert is stripped under `python -O`; consider raising instead.
71 assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
# Presumably yy is y reshaped as a column, and yk / ykp1 are the left / right
# breakpoints of each segment (defined on lines not shown). The indicator
# selects the segment containing yy, and k + (yy - yk) / (ykp1 - yk) is the
# fractional segment position, which is rescaled back to the input range.
# NOTE(review): the upper test is strict (yy < ykp1), so y equal to the last
# breakpoint (accepted by the assert) matches no segment. Confirm how the
# reduction (not shown) treats that case.
74 x = self.xmin + (self.xmax - self.xmin) / self.nb * (
75 (yy >= yk) * (yy < ykp1).long() * (k + (yy - yk) / (ykp1 - yk))
80 ######################################################################
# Fit the flow f by maximum likelihood. With z = f(x) and a standard normal
# prior on z, log p(x) = log N(f(x); 0, 1) + log f'(x).
87 model = PiecewiseLinear(nb=1001, xmin=-4, xmax=4)
89 train_input = sample_phi(nb_samples)
91 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# NOTE(review): criterion is not used by the visible loss computation.
92 criterion = nn.MSELoss()
94 for k in range(nb_epochs):
97 for input in train_input.split(batch_size):
# Track gradients w.r.t. the samples so that f'(x) can be obtained via autograd.
98 input.requires_grad_()
# output is computed on a line not shown (presumably model(input)). The map
# acts elementwise, so differentiating output.sum() yields the per-sample
# derivative f'(x). create_graph=True keeps that derivative differentiable, so
# the loss can backpropagate through it.
101 (derivatives,) = autograd.grad(
102 output.sum(), input, retain_graph=True, create_graph=True
# Negative log-likelihood, averaged over the batch.
# NOTE(review): f'(x) is 0 for samples outside [xmin, xmax], where log()
# returns -inf.
105 loss = (0.5 * (output**2 + math.log(2 * pi)) - derivatives.log()).mean()
107 optimizer.zero_grad()
111 acc_loss += loss.item()
# NOTE(review): this prints the last batch's loss. acc_loss is accumulated but
# not used in the visible code.
113 print(k, loss.item())
115 ######################################################################
# Evaluate densities on a regular grid of 175 points in [-3, 3].
117 input = torch.linspace(-3, 3, 175)
# Standard-normal density on the grid (zero as LogProba's second argument).
120 mu_N = torch.exp(LogProba(input, 0))
122 input.requires_grad_()
123 output = model(input)
# Learned density by change of variables, N(f(x)) * f'(x), with log f'(x)
# passed as LogProba's second argument. This assumes LogProba(x, ldj) returns
# log N(x; 0, 1) + ldj.
125 grad = autograd.grad(output.sum(), input)[0]
126 mu_hat = LogProba(output, grad.log()).detach().exp()
128 ######################################################################
# Convert tensors to NumPy arrays for matplotlib. input and output still carry
# autograd history, hence detach(). mu_hat was already detached.
131 input = input.detach().numpy()
132 output = output.detach().numpy()
134 mu_hat = mu_hat.numpy()
136 ######################################################################
# Plot the learned mapping x -> f(x) and save it as miniflow_mapping.pdf.
# fig is presumably created on a line not shown here.
139 ax = fig.add_subplot(1, 1, 1)
141 # ax.set_ylim(-0.25, 1.25)
144 ax.plot(input, output, "-", color="tab:red")
146 filename = "miniflow_mapping.pdf"
147 print(f"Saving {filename}")
148 fig.savefig(filename, bbox_inches="tight")
152 ######################################################################
# Fill colour used by the density plots.
154 green_dist = "#bfdfbf"
# Flow picture: each grid point x is joined by a segment from (x, 0) to
# (f(x), 0.5).
157 ax = fig.add_subplot(1, 1, 1)
158 # ax.set_xlim(-4.5, 4.5)
159 # ax.set_ylim(-0.1, 1.1)
161 ([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output)
163 lc = mc.LineCollection(lines, color="tab:red", linewidth=0.1)
164 ax.add_collection(lc)
# The standard-normal density (scaled by 0.2) is drawn from baseline y = 0.52.
# mu, presumably the target density on the grid (computed on a line not shown),
# is drawn from baseline y = -0.30.
167 ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color=green_dist)
168 ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color=green_dist)
170 filename = "miniflow_flow.pdf"
171 print(f"Saving {filename}")
172 fig.savefig(filename, bbox_inches="tight")
176 ######################################################################
# Compare the target density mu (filled green) with the learned density mu_hat
# (red line), and save as miniflow_dist.pdf.
179 ax = fig.add_subplot(1, 1, 1)
182 ax.fill_between(input, 0, mu, color=green_dist)
183 # ax.plot(input, mu, '-', color = 'tab:blue')
184 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
185 ax.plot(input, mu_hat, "-", color="tab:red")
187 filename = "miniflow_dist.pdf"
188 print(f"Saving {filename}")
189 fig.savefig(filename, bbox_inches="tight")
193 ######################################################################
# Target density mu alone (filled green), saved as miniflow_target_dist.pdf.
196 ax = fig.add_subplot(1, 1, 1)
199 # ax.plot(input, mu, '-', color = 'tab:blue')
200 ax.fill_between(input, 0, mu, color=green_dist)
201 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
203 filename = "miniflow_target_dist.pdf"
204 print(f"Saving {filename}")
205 fig.savefig(filename, bbox_inches="tight")
209 ######################################################################
211 # z = torch.empty(200).normal_()
212 # z = z[(z > -3) * (z < 3)]
214 # x = model.invert(z)
217 # ax = fig.add_subplot(1, 1, 1)
218 # ax.set_xlim(-4.5, 4.5)
219 # ax.set_ylim(-0.1, 1.1)
220 # lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
221 # lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
222 # ax.add_collection(lc)
225 # # ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
226 # # ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
228 # filename = 'miniflow_synth.pdf'
229 # print(f'Saving {filename}')
230 # fig.savefig(filename, bbox_inches='tight')