3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
15 import torch, torchvision
17 from torch import nn, autograd
18 from torch.nn import functional as F
20 ######################################################################
# Weighted sum of two terms: weight (1 - p) on the term shifted by +0.5 and
# weight p on the term shifted by -0.5. Both arguments are rescaled by std,
# and log(1 / std) is passed as the second argument of LogProba.
# NOTE(review): presumably this is a two-component Gaussian mixture density
# with log(1 / std) as the change-of-scale correction -- confirm against
# LogProba and the definitions of p and std.
24 mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
25 p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
# Draw nb values from a normal distribution with mean 0 and standard
# deviation std.
30 result = torch.empty(nb).normal_(0, std)
# Shift every value by +/- 0.5: rand() is uniform in [0, 1), so the sign is
# -1 with probability p and +1 otherwise.
# NOTE(review): presumably p and std match the values used by the density
# computed from the same two shifts -- confirm.
31 result = result + torch.sign(torch.rand(result.size()) - p) / 2
34 ######################################################################
# Standard normal log-density of x, -0.5 * (x**2 + log(2 * pi)), plus ldj.
# NOTE(review): ldj presumably stands for the log-determinant of the
# Jacobian of the change of variable -- confirm with the callers.
37 log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
40 ######################################################################
# Piecewise-linear map defined by nb + 1 knot values over a regular grid of
# nb bins. The knot values are alpha + cumsum(exp(xi)), hence strictly
# increasing, which makes the interpolated map non-decreasing.
42 class PiecewiseLinear(nn.Module):
# nb is the number of bins; xmin and xmax are the bounds of the input grid.
# NOTE(review): self.nb, self.xmin and self.xmax are read further down;
# presumably they are stored here from the constructor arguments -- confirm.
43 def __init__(self, nb, xmin, xmax):
# Learnable offset of the first knot, a one-element float tensor
# initialised to xmin.
48 self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
# Log of the width of one bin when the range is split into nb equal bins.
49 mu = math.log((xmax - xmin) / nb)
# nb + 1 learnable log-increments, initialised very close to mu
# (standard deviation 1e-4) so that exp(xi) starts near one bin width.
50 self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
# Knot values: strictly increasing since every exp(xi) is positive.
53 y = self.alpha + self.xi.exp().cumsum(0)
# Rescale x so that [xmin, xmax] maps onto [0, nb].
54 u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
# Bin index, clamped so that n + 1 stays a valid index into y.
55 n = u.long().clamp(0, self.nb - 1)
# Position inside the bin; the clamp makes inputs outside [xmin, xmax]
# saturate at the first / last interpolated value.
56 a = (u - n).clamp(0, 1)
# Linear interpolation between the two knots surrounding x.
57 x = (1 - a) * y[n] + a * y[n + 1]
# Same knot values as above, with the cumulative sum reshaped as a row of
# nb + 1 entries.
61 ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
# Row of bin indices 0 .. nb - 1.
63 k = torch.arange(self.nb).view(1, -1)
# Every value to invert must lie between the first knot and the knot of
# index nb (.min() on a boolean tensor acts as "all").
64 assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
# For each value, select the bin whose knots bracket it, add the fractional
# position inside that bin to the bin index, and map the result back to
# input coordinates. Note that .long() applies to (yy < ykp1) only.
# NOTE(review): presumably yy is y as a column and yk / ykp1 are the lower /
# upper knots of each bin -- confirm against their definitions.
67 x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
70 ######################################################################
# Fit the piecewise-linear map by minimising a negative log-likelihood
# computed from the model output and its derivative w.r.t. the input.
# The map uses 1001 bins over [-4, 4].
77 model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
# Training set: nb_samples values produced by sample_phi.
79 train_input = sample_phi(nb_samples)
81 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# NOTE(review): the loss below is computed directly and does not go through
# criterion -- check whether this is still needed.
82 criterion = nn.MSELoss()
84 for k in range(nb_epochs):
# Mini-batches of batch_size consecutive samples.
87 for input in train_input.split(batch_size):
# Needed to differentiate the model output with respect to its input.
88 input.requires_grad_()
# create_graph = True keeps the derivative itself differentiable, so that
# the loss can be back-propagated through it. The trailing comma unpacks
# the one-element tuple returned by autograd.grad.
91 derivatives, = autograd.grad(
93 retain_graph = True, create_graph = True
# Mean over the batch of the negated standard normal log-density of the
# output, minus the log of the derivative.
# NOTE(review): presumably output = model(input) and derivatives is
# d output / d input -- confirm.
96 loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
102 acc_loss += loss.item()
# Prints the epoch index and the latest loss value (not acc_loss) when k is
# a multiple of 10.
103 if k%10 == 0: print(k, loss.item())
105 ######################################################################
# Evaluate the trained map and the densities on a regular grid of 175
# points in [-3, 3].
107 input = torch.linspace(-3, 3, 175)
# LogProba evaluated on the grid with a zero second argument.
# NOTE(review): presumably this is the standard normal density -- confirm
# against LogProba.
110 mu_N = torch.exp(LogProba(input, 0))
112 input.requires_grad_()
113 output = model(input)
# Gradient of the summed outputs with respect to the grid points.
# NOTE(review): this equals the per-point derivative of the map provided
# the model acts element-wise -- confirm.
115 grad = autograd.grad(output.sum(), input)[0]
# LogProba of the mapped points with the log-derivative as second argument,
# detached and exponentiated.
# NOTE(review): presumably the density estimated by the model via the
# change-of-variable formula -- confirm.
116 mu_hat = LogProba(output, grad.log()).detach().exp()
118 ######################################################################
# Convert the tensors to numpy arrays for plotting. input and output are
# detached first; mu_hat was already detached when it was computed.
121 input = input.detach().numpy()
122 output = output.detach().numpy()
124 mu_hat = mu_hat.numpy()
126 ######################################################################
# Figure 1: the learned mapping, model output against input, as a red curve.
# NOTE(review): fig is presumably a fresh matplotlib figure created just
# before -- confirm.
129 ax = fig.add_subplot(1, 1, 1)
131 # ax.set_ylim(-0.25, 1.25)
134 ax.plot(input, output, '-', color = 'tab:red')
136 filename = 'miniflow_mapping.pdf'
137 print(f'Saving {filename}')
138 fig.savefig(filename, bbox_inches='tight')
142 ######################################################################
# Figure 2: one thin red segment per grid point, from (input, 0) to
# (output, 0.5), with two filled curves drawn above and below.
# Fill colour shared by the filled curves of the following figures.
144 green_dist = '#bfdfbf'
147 ax = fig.add_subplot(1, 1, 1)
148 # ax.set_xlim(-4.5, 4.5)
149 # ax.set_ylim(-0.1, 1.1)
# One segment per (input, output) pair.
150 lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
151 lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
152 ax.add_collection(lc)
# mu_N scaled by 0.2 on a baseline at 0.52, just above the segment tops.
155 ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
# mu scaled by 0.2 on a baseline at -0.30, below the segment bottoms.
# NOTE(review): mu is presumably the target density evaluated on the grid
# -- confirm.
156 ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
158 filename = 'miniflow_flow.pdf'
159 print(f'Saving {filename}')
160 fig.savefig(filename, bbox_inches='tight')
164 ######################################################################
# Figure 3: mu as a filled area with mu_hat drawn on top as a red curve.
167 ax = fig.add_subplot(1, 1, 1)
170 ax.fill_between(input, 0, mu, color = green_dist)
171 # ax.plot(input, mu, '-', color = 'tab:blue')
172 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
173 ax.plot(input, mu_hat, '-', color = 'tab:red')
175 filename = 'miniflow_dist.pdf'
176 print(f'Saving {filename}')
177 fig.savefig(filename, bbox_inches='tight')
181 ######################################################################
# Figure 4: mu alone, as a filled area.
184 ax = fig.add_subplot(1, 1, 1)
187 # ax.plot(input, mu, '-', color = 'tab:blue')
188 ax.fill_between(input, 0, mu, color = green_dist)
189 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
191 filename = 'miniflow_target_dist.pdf'
192 print(f'Saving {filename}')
193 fig.savefig(filename, bbox_inches='tight')
197 ######################################################################
199 # z = torch.empty(200).normal_()
200 # z = z[(z > -3) * (z < 3)]
202 # x = model.invert(z)
205 # ax = fig.add_subplot(1, 1, 1)
206 # ax.set_xlim(-4.5, 4.5)
207 # ax.set_ylim(-0.1, 1.1)
208 # lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
209 # lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
210 # ax.add_collection(lc)
213 # # ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
214 # # ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
216 # filename = 'miniflow_synth.pdf'
217 # print(f'Saving {filename}')
218 # fig.savefig(filename, bbox_inches='tight')