Update.

author     François Fleuret <francois@fleuret.org>
           Fri, 1 Mar 2024 23:36:04 +0000 (00:36 +0100)
committer  François Fleuret <francois@fleuret.org>
           Fri, 1 Mar 2024 23:36:04 +0000 (00:36 +0100)
parent     39ce2c3
diff --git a/tiny_vae.py b/tiny_vae.py
index 0895830..784f775 100755 (executable)
--- a/tiny_vae.py
+++ b/tiny_vae.py
@@ -28,9 +28,9 @@
 parser = argparse.ArgumentParser(
     description="Very simple implementation of a VAE for teaching."
 )
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=100)
 
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=2e-4)
 
 parser.add_argument("--batch_size", type=int, default=100)
@@ -44,12 +44,6 @@
 parser.add_argument("--nb_channels", type=int, default=128)
 
 parser.add_argument("--no_dkl", action="store_true")
 
-# With that option, do not follow the setup of the original VAE paper
-# of forcing the variance of X|Z to 1 during training and to 0 for
-# sampling, but optimize and use the variance.
-
-parser.add_argument("--no_hacks", action="store_true")
-
 args = parser.parse_args()
 log_file = open(args.log_filename, "w")
@@ -145,8 +139,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
 
-        if not args.no_hacks:
-            log_var[...] = 0
+        # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
 
         return mu, log_var
@@ -239,20 +232,14 @@
 save_image(x, "input.png")
 mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
 z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "output.png")
 
 # Generate a bunch of images
 
 z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "synth.png")
 ######################################################################
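
Note on the change: the deleted --no_hacks flag gated what this commit makes unconditional. By default the script used to follow the original VAE paper's setup, clamping the decoder's log-variance to zero during training and returning the mean when sampling; with the flag, the variance was optimized and used instead. After this commit the predicted variance is always kept, and outputs are always drawn from the learned Gaussian p(X|Z). A minimal sketch of the sampling step, assuming sample_gaussian is the usual reparameterized draw (its definition lies outside this diff):

import torch

def sample_gaussian(mu, log_var):
    # Assumed definition (not shown in the diff): reparameterized draw
    # mu + sigma * eps, with sigma = exp(log_var / 2) and eps ~ N(0, I).
    return mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)

# Hypothetical decoder outputs for a batch of four one-channel images.
mu = torch.zeros(4, 1, 28, 28)
log_var = torch.full_like(mu, -2.0)

x_before = mu                           # old default path: variance ignored, mean returned
x_after = sample_gaussian(mu, log_var)  # new path: sample from the learned p(X|Z)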