Update.
[pytorch.git] / elbo.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9
10
11 def D_KL(a, b):
12     return -a @ (b / a).log()
13
14
15 # p(X = x, Z = z) = p[x, z]
16
17 p_XZ = torch.rand(5, 4)
18 p_XZ /= p_XZ.sum()
19 q_XZ = torch.rand(p_XZ.size())
20 q_XZ /= q_XZ.sum()
21
22 p_X = p_XZ.sum(1)
23 p_Z = p_XZ.sum(0)
24 p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim=True)
25 p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim=True)
26
27 # q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True)
28 q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim=True)
29
30 for x in range(p_XZ.size(0)):
31     elbo = q_Z_given_X[x, :] @ (p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log()
32     print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))