projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
elbo.py
diff --git
a/elbo.py
b/elbo.py
index
6af4a77
..
dbea3b5
100755
(executable)
--- a/
elbo.py
+++ b/
elbo.py
@@
-7,8
+7,10
@@
import torch
import torch
+
def D_KL(a, b):
def D_KL(a, b):
- return - a @ (b / a).log()
+ return -a @ (b / a).log()
+
# p(X = x, Z = z) = p[x, z]
# p(X = x, Z = z) = p[x, z]
@@
-19,12
+21,12
@@
q_XZ /= q_XZ.sum()
p_X = p_XZ.sum(1)
p_Z = p_XZ.sum(0)
p_X = p_XZ.sum(1)
p_Z = p_XZ.sum(0)
-p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim
=
True)
-p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim
=
True)
+p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim
=
True)
+p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim
=
True)
-#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True)
-q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim
=
True)
+#
q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True)
+q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim
=
True)
for x in range(p_XZ.size(0)):
for x in range(p_XZ.size(0)):
- elbo = q_Z_given_X[x, :] @ (
p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log()
+ elbo = q_Z_given_X[x, :] @ (p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log()
print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))
print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))