projects / pytorch.git / commitdiff
commit · grep · author · committer · pickaxe  ?  search: [ ] re
summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: c951f0b)
Update.
author    François Fleuret <francois@fleuret.org>
          Fri, 1 Mar 2024 21:34:54 +0000 (22:34 +0100)
committer François Fleuret <francois@fleuret.org>
          Fri, 1 Mar 2024 21:34:54 +0000 (22:34 +0100)
tiny_vae.py    patch | blob | history
diff --git a/tiny_vae.py b/tiny_vae.py
index bbdbf1a..577f717 100755 (executable)
--- a/tiny_vae.py
+++ b/tiny_vae.py
@@ -26,7 +26,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=100)
 parser.add_argument("--batch_size", type=int, default=100)
@@ -75,13 +75,13 @@ def log_p_gaussian(x, mu, log_var):
 )
-def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
-    mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
-    mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
+    mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
+    mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
     var_a = log_var_a.exp()
     var_b = log_var_b.exp()
     return 0.5 * (
-        log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
+        log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
     ).sum(1)
@@ -176,27 +176,27 @@ test_input.sub_(train_mu).div_(train_std)
 ######################################################################
-mu_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mu_p_Z
+mean_p_Z = train_input.new_zeros(1, args.latent_dim)
+log_var_p_Z = mean_p_Z
 for epoch in range(args.nb_epochs):
     acc_loss = 0
     for x in train_input.split(args.batch_size):
-        mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-        z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-        mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+        mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+        z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
+        mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
         if args.no_dkl:
-            log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x)
+            log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
             log_p_x_z = log_p_gaussian(
-                x, mu_p_X_given_z, log_var_p_X_given_z
-            ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z)
+                x, mean_p_X_given_z, log_var_p_X_given_z
+            ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
-            log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z)
+            log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
-                mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z
+                mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
             )
             loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
@@ -224,17 +224,17 @@ save_image(x, "input.png")
 # Save the same images after encoding / decoding
-mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
+mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "output.png")
 # Generate a bunch of images
-z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
+mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "synth.png")
 ######################################################################