######################################################################
-# START_LOG_PROBA
def LogProba(x, ldj):
log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
return log_p
-# END_LOG_PROBA
######################################################################
-# START_MODEL
class PiecewiseLinear(nn.Module):
def __init__(self, nb, xmin, xmax):
- super(PiecewiseLinear, self).__init__()
+ super().__init__()
self.xmin = xmin
self.xmax = xmax
self.nb = nb
a = (u - n).clamp(0, 1)
x = (1 - a) * y[n] + a * y[n + 1]
return x
-# END_MODEL
def invert(self, y):
ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
for k in range(nb_epochs):
acc_loss = 0
-# START_OPTIMIZATION
for input in train_input.split(batch_size):
input.requires_grad_()
output = model(input)
optimizer.zero_grad()
loss.backward()
optimizer.step()
-# END_OPTIMIZATION
acc_loss += loss.item()
if k%10 == 0: print(k, loss.item())
######################################################################
-z = torch.empty(200).normal_()
-z = z[(z > -3) * (z < 3)]
+# z = torch.empty(200).normal_()
+# z = z[(z > -3) * (z < 3)]
-x = model.invert(z)
+# x = model.invert(z)
-fig = plt.figure()
-ax = fig.add_subplot(1, 1, 1)
-ax.set_xlim(-4.5, 4.5)
-ax.set_ylim(-0.1, 1.1)
-lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
-lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
-ax.add_collection(lc)
-# ax.axis('off')
-
-# ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
-# ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
+# fig = plt.figure()
+# ax = fig.add_subplot(1, 1, 1)
+# ax.set_xlim(-4.5, 4.5)
+# ax.set_ylim(-0.1, 1.1)
+# lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
+# lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+# ax.add_collection(lc)
+# # ax.axis('off')
-filename = 'miniflow_synth.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+# # ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist)
+# # ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist)
-# plt.show()
+# filename = 'miniflow_synth.pdf'
+# print(f'Saving {filename}')
+# fig.savefig(filename, bbox_inches='tight')
+# # plt.show()