From 47525ec795faca1ab72aee13956a553d070c81b7 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 14 Mar 2022 13:22:13 +0100 Subject: [PATCH] Update. --- ddpol.py | 9 +++++---- elbo.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/ddpol.py b/ddpol.py index f33b0a1..645f47c 100755 --- a/ddpol.py +++ b/ddpol.py @@ -50,12 +50,13 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): r = q.view(-1, 1) beta = x.new_zeros(D + 1, D + 1) beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3) - l, U = beta.eig(eigenvectors = True) - Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values + W = torch.linalg.eig(beta) + l, U = W.eigenvalues.real, W.eigenvectors.real + Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) - return torch.lstsq(B, M).solution[:D+1, 0] + return torch.linalg.lstsq(M, B).solution[:D+1] ###################################################################### @@ -99,7 +100,7 @@ ax.set_ylabel('MSE', labelpad = 10) ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5, linestyle = '--') -ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples', +ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples', fontsize = 10, color = 'gray', rotation = 90, rotation_mode='anchor') diff --git a/elbo.py b/elbo.py index 24155fe..6af4a77 100755 --- a/elbo.py +++ b/elbo.py @@ -7,23 +7,24 @@ import torch -def D_KL(p, q): - return - p @ (q / p).log() +def D_KL(a, b): + return - a @ (b / a).log() # p(X = x, Z = z) = p[x, z] -p = torch.rand(5, 4) -p /= p.sum() -q = torch.rand(p.size()) -q /= q.sum() +p_XZ = torch.rand(5, 4) +p_XZ /= p_XZ.sum() +q_XZ = torch.rand(p_XZ.size()) +q_XZ /= q_XZ.sum() -p_X = p.sum(1) -p_Z = p.sum(0) -p_X_given_Z = p / p.sum(0, keepdim = True) -p_Z_given_X = p / p.sum(1, keepdim = True) -q_X_given_Z = q / q.sum(0, keepdim = True) -q_Z_given_X = q / q.sum(1, keepdim = True) +p_X = p_XZ.sum(1) +p_Z = p_XZ.sum(0) +p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim = True) +p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim = True) -for x in range(p.size(0)): +#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True) +q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim = True) + +for x in range(p_XZ.size(0)): elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :])) -- 2.39.5