3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
15 import torch, torchvision
17 from torch import nn, autograd
18 from torch.nn import functional as F
20 ######################################################################
# NOTE(review): every line of this excerpt carries a stray leading line number,
# has lost its indentation, and several lines (def/return lines, constants) are
# missing, so the file does not parse as Python as-is; restore it from the
# original source before running.
# Target density on x: a two-component Gaussian mixture, weight (1 - p) on a
# component centred at +0.5 and weight p on one centred at -0.5, both with
# standard deviation `std`. Each component is exp(LogProba(z, log(1/std))) on
# the standardised z, i.e. the N(0, 1) density rescaled by the 1/std Jacobian
# (assuming LogProba is the standard-normal log-density plus its second arg).
# NOTE(review): `x`, `p` and `std` are not defined in the visible lines;
# presumably the enclosing function's argument and constants - confirm they
# match the values used by the sampler.
24 mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
25 p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
# Draw `nb` samples from the target mixture: start from N(0, std^2) noise, then
# shift each sample by sign(U - p) / 2 with U ~ Uniform[0, 1), i.e. by -1/2 with
# probability p and by +1/2 otherwise (matching the weights of the density).
# NOTE(review): `nb`, `p` and `std` are not defined in the visible lines;
# presumably this is the body of sample_phi(nb) (called as
# sample_phi(nb_samples)) and p/std are set on missing lines - confirm.
30 result = torch.empty(nb).normal_(0, std)
31 result = result + torch.sign(torch.rand(result.size()) - p) / 2
34 ######################################################################
# Log-density of a standard normal at x, offset by a log-Jacobian term `ldj`:
#   log_p = ldj - (x^2 + log(2*pi)) / 2.
# Call sites such as LogProba(input, 0) and LogProba(output, grad.log())
# suggest this is the body of LogProba(x, ldj) - confirm.
# NOTE(review): `pi` is not defined in the visible lines; presumably it is
# bound to math.pi on a missing line - confirm (math.pi is the safe spelling).
38 log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
42 ######################################################################
# Monotonic piecewise-linear map on [xmin, xmax] split into `nb` equal-width
# segments. The nb + 1 knot ordinates are alpha + cumsum(exp(xi)); exp(xi) > 0,
# so they are strictly increasing and the map is invertible.
# NOTE(review): this excerpt is missing lines (the self.xmin / self.xmax /
# self.nb assignments, the method `def` lines, the definitions of yy, yk, ykp1
# and the return statements); the comments below describe only visible code.
45 class PiecewiseLinear(nn.Module):
46 def __init__(self, nb, xmin, xmax):
47 super(PiecewiseLinear, self).__init__()
# alpha: learnable offset of the knot ordinates, initialised to xmin.
51 self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
# xi: log segment heights, initialised (with tiny noise) to the log of the
# segment width so every segment starts with slope close to 1.
52 mu = math.log((xmax - xmin) / nb)
53 self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
# Forward map (presumably forward(self, x), used via model(input)).
# y: knot ordinates y[0..nb].
56 y = self.alpha + self.xi.exp().cumsum(0)
# u: position of x measured from xmin in units of the segment width.
57 u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
# n: segment index, clamped to [0, nb - 1]; a: position inside the segment,
# clamped to [0, 1]. Together they make the output saturate at y[0] below xmin
# and at y[nb] above xmax (zero slope outside the range).
58 n = u.long().clamp(0, self.nb - 1)
59 a = (u - n).clamp(0, 1)
# Linear interpolation between the two knots bounding segment n.
60 x = (1 - a) * y[n] + a * y[n + 1]
# Inverse map (presumably invert(self, y), cf. the disabled `model.invert(z)`).
# ys: knot ordinates as a (1, nb + 1) row.
65 ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
# k: segment indices 0..nb-1 as a (1, nb) row.
67 k = torch.arange(self.nb).view(1, -1)
# Require every y to lie within [first knot, last knot]; `.min()` of a boolean
# tensor is used here as an "all" test.
# NOTE(review): `.all()` states the intent directly, and `assert` is removed
# under `python -O`, so this check is not a reliable input validation.
68 assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
# Select each y's segment with the 0/1 mask (yk <= y < ykp1) and map back
# linearly: x = xmin + width * (k + (y - yk) / (ykp1 - yk)).
# NOTE(review): presumably yy is y as a column and yk / ykp1 are the left /
# right knots of each segment. If so, y equal to the last knot (accepted by
# the assert) falls in no half-open interval, the masked sum is 0 and the
# result is xmin instead of xmax - confirm and close the last interval.
71 x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
74 ######################################################################
# Fit the flow by maximum likelihood on samples from the target mixture. The
# loss is the change-of-variables negative log-likelihood
#   -log N(f(x); 0, 1) - log f'(x), averaged over the minibatch.
# NOTE(review): nb_samples, nb_epochs, batch_size, the acc_loss reset, the
# computation of `output`, the first autograd.grad arguments and the
# backward()/step() calls are on lines missing from this excerpt.
81 model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
83 train_input = sample_phi(nb_samples)
85 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# NOTE(review): `criterion` is not used by any visible line; the loss below is
# written out by hand.
86 criterion = nn.MSELoss()
88 for k in range(nb_epochs):
# NOTE(review): `input` shadows the Python builtin of the same name.
92 for input in train_input.split(batch_size):
# Track gradients w.r.t. the inputs so f'(x) can be obtained through autograd.
93 input.requires_grad_()
# Per-sample derivative f'(x) (presumably autograd.grad(output.sum(), input,
# ...) with the arguments on a missing line). create_graph=True keeps it
# differentiable so the -log f'(x) term contributes to parameter gradients.
96 derivatives, = autograd.grad(
98 retain_graph = True, create_graph = True
# Mean NLL. NOTE(review): PiecewiseLinear is flat outside [xmin, xmax], so any
# training sample outside [-4, 4] would give log(0) = -inf here.
101 loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
103 optimizer.zero_grad()
108 acc_loss += loss.item()
# NOTE(review): this reports the last minibatch's loss only; the accumulated
# acc_loss is never printed in the visible lines.
109 if k%10 == 0: print(k, loss.item())
111 ######################################################################
# Evaluate the trained flow on 175 evenly spaced points in [-3, 3].
113 input = torch.linspace(-3, 3, 175)
# mu_N: standard normal density on the grid (LogProba with zero log-Jacobian).
116 mu_N = torch.exp(LogProba(input, 0))
118 input.requires_grad_()
119 output = model(input)
# The forward map acts elementwise, so d(sum f)/dx yields f'(x) per grid point.
121 grad = autograd.grad(output.sum(), input)[0]
# Density the model assigns to x by change of variables: N(f(x); 0, 1) * f'(x).
122 mu_hat = LogProba(output, grad.log()).detach().exp()
124 ######################################################################
# Convert to NumPy arrays for matplotlib; detach() first because .numpy()
# refuses tensors that require grad (mu_hat is already detached).
127 input = input.detach().numpy()
128 output = output.detach().numpy()
130 mu_hat = mu_hat.numpy()
132 ######################################################################
# Plot the learned mapping x -> f(x) on the evaluation grid and save it as
# 'miniflow_mapping.pdf'.
# NOTE(review): `fig` is created on a line missing from this excerpt.
135 ax = fig.add_subplot(1, 1, 1)
137 # ax.set_ylim(-0.25, 1.25)
140 ax.plot(input, output, '-', color = 'tab:red')
142 filename = 'miniflow_mapping.pdf'
143 print(f'Saving {filename}')
144 fig.savefig(filename, bbox_inches='tight')
148 ######################################################################
# Visualise the flow: one thin red segment per grid point from (x, 0) to
# (f(x), 0.5). The standard normal density (scaled by 0.2) is filled above,
# from 0.52, and mu (scaled by 0.2) is filled below, from -0.30. The result is
# saved as 'miniflow_flow.pdf'.
# NOTE(review): `fig` and `mu` are defined on lines missing from this excerpt;
# presumably mu is the target density evaluated on the grid - confirm.
150 green_dist = '#bfdfbf'
153 ax = fig.add_subplot(1, 1, 1)
154 # ax.set_xlim(-4.5, 4.5)
155 # ax.set_ylim(-0.1, 1.1)
# One [(x, 0), (f(x), 0.5)] segment per grid point. NOTE(review): the extra
# parentheses inside list(...) are redundant; a list comprehension reads better.
156 lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
157 lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
158 ax.add_collection(lc)
161 ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
162 ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
164 filename = 'miniflow_flow.pdf'
165 print(f'Saving {filename}')
166 fig.savefig(filename, bbox_inches='tight')
170 ######################################################################
# Compare the target density mu (filled green) with the model density mu_hat
# (red line) on the evaluation grid; save as 'miniflow_dist.pdf'.
# NOTE(review): `fig` and `mu` are defined on lines missing from this excerpt.
173 ax = fig.add_subplot(1, 1, 1)
176 ax.fill_between(input, 0, mu, color = green_dist)
177 # ax.plot(input, mu, '-', color = 'tab:blue')
178 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
179 ax.plot(input, mu_hat, '-', color = 'tab:red')
181 filename = 'miniflow_dist.pdf'
182 print(f'Saving {filename}')
183 fig.savefig(filename, bbox_inches='tight')
187 ######################################################################
# Plot the target density mu alone (filled green); save as
# 'miniflow_target_dist.pdf'.
# NOTE(review): `fig` and `mu` are defined on lines missing from this excerpt.
190 ax = fig.add_subplot(1, 1, 1)
193 # ax.plot(input, mu, '-', color = 'tab:blue')
194 ax.fill_between(input, 0, mu, color = green_dist)
195 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
197 filename = 'miniflow_target_dist.pdf'
198 print(f'Saving {filename}')
199 fig.savefig(filename, bbox_inches='tight')
203 ######################################################################
# Disabled synthesis figure. It would:
# - draw 200 N(0, 1) samples and keep those in (-3, 3);
# - map them back through model.invert;
# - draw one red segment per sample from (x, 0) to (z, 0.5);
# - save the result as 'miniflow_synth.pdf'.
# NOTE(review): invert asserts its inputs lie within the learned knot range;
# confirm that range covers (-3, 3) after training before re-enabling.
205 # z = torch.empty(200).normal_()
206 # z = z[(z > -3) * (z < 3)]
208 # x = model.invert(z)
211 # ax = fig.add_subplot(1, 1, 1)
212 # ax.set_xlim(-4.5, 4.5)
213 # ax.set_ylim(-0.1, 1.1)
214 # lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
215 # lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
216 # ax.add_collection(lc)
219 # # ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
220 # # ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
222 # filename = 'miniflow_synth.pdf'
223 # print(f'Saving {filename}')
224 # fig.savefig(filename, bbox_inches='tight')