From ce8deb0b3038c17d3198a447a7a3c44a93606d38 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 2 Mar 2024 00:36:04 +0100 Subject: [PATCH] Update. --- tiny_vae.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/tiny_vae.py b/tiny_vae.py index 0895830..784f775 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -28,9 +28,9 @@ parser = argparse.ArgumentParser( description="Very simple implementation of a VAE for teaching." ) -parser.add_argument("--nb_epochs", type=int, default=25) +parser.add_argument("--nb_epochs", type=int, default=100) -parser.add_argument("--learning_rate", type=float, default=1e-3) +parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--batch_size", type=int, default=100) @@ -44,12 +44,6 @@ parser.add_argument("--nb_channels", type=int, default=128) parser.add_argument("--no_dkl", action="store_true") -# With that option, do not follow the setup of the original VAE paper -# of forcing the variance of X|Z to 1 during training and to 0 for -# sampling, but optimize and use the variance. - -parser.add_argument("--no_hacks", action="store_true") - args = parser.parse_args() log_file = open(args.log_filename, "w") @@ -145,8 +139,7 @@ class ImageGivenLatentNet(nn.Module): def forward(self, z): output = self.model(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] - if not args.no_hacks: - log_var[...] = 0 + # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1] return mu, log_var @@ -239,20 +232,14 @@ save_image(x, "input.png") mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -if args.no_hacks: - x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) -else: - x = mean_p_X_given_z +x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) save_image(x, "output.png") # Generate a bunch of images z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -if args.no_hacks: - x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) -else: - x = mean_p_X_given_z +x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) save_image(x, "synth.png") ###################################################################### -- 2.39.5