import torch
def D_KL(a, b):
    """Return the Kullback-Leibler divergence KL(a || b).

    Evaluates ``sum_i a[i] * log(a[i] / b[i])``, written here as
    ``-a @ log(b / a)``.

    Args:
        a: Tensor used as the left operand of ``@``; the script below
            passes 1-D rows of a conditional probability table.
        b: Tensor with the same shape as ``a``.

    Returns:
        The result of the ``@`` product (a 0-dim tensor for 1-D inputs).

    NOTE(review): entries are presumably strictly positive; a zero in
    ``a`` or ``b`` would make the log/ratio non-finite -- confirm callers.
    """
    return -a @ (b / a).log()


# Numerical check of the identity
#     log p(x) = ELBO(x) + KL(q(Z | x) || p(Z | x))
# on random discrete distributions.

# p(X = x, Z = z) = p_XZ[x, z]
p_XZ = torch.rand(5, 4)
p_XZ /= p_XZ.sum()  # normalise so the whole table sums to 1

# A second, unrelated joint table; only its conditional q(Z | X) is used.
q_XZ = torch.rand(p_XZ.size())
q_XZ /= q_XZ.sum()

# Marginals of p: summing over dim 1 (Z) gives p(X), over dim 0 (X) gives p(Z).
p_X = p_XZ.sum(1)
p_Z = p_XZ.sum(0)

# Conditionals of p: divide the joint by the matching marginal.
p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim=True)
p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim=True)

# Approximate posterior q(Z | X): each row of q_XZ normalised to sum to 1.
q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim=True)

for x in range(p_XZ.size(0)):
    # ELBO(x) = sum_z q(z | x) * log(p(x | z) * p(z) / q(z | x))
    elbo = q_Z_given_X[x, :] @ (
        p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z
    ).log()
    # The two printed values should agree up to floating-point error.
    print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))