From 9529aef0fdc1d1810e3c331a01fdbe640669ecfc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 1 Mar 2024 23:21:21 +0100 Subject: [PATCH] Update. --- tiny_vae.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tiny_vae.py b/tiny_vae.py index 577f717..d33cc4b 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -24,9 +24,13 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") +parser = argparse.ArgumentParser( + description="Very simple implementation of a VAE for teaching." +) + +parser.add_argument("--nb_epochs", type=int, default=25) -parser.add_argument("--nb_epochs", type=int, default=100) +parser.add_argument("--learning_rate", type=float, default=1e-3) parser.add_argument("--batch_size", type=int, default=100) @@ -40,6 +44,11 @@ parser.add_argument("--nb_channels", type=int, default=128) parser.add_argument("--no_dkl", action="store_true") +# With that option, do not follow the setup of the original VAE paper +# of forcing the variance of X|Z to 1 during training and to 0 for +# sampling, but optimize and use the variance. +parser.add_argument("--no_hacks", action="store_true") + args = parser.parse_args() log_file = open(args.log_filename, "w") @@ -135,6 +144,8 @@ class ImageGivenLatentNet(nn.Module): def forward(self, z): output = self.model(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] + if not args.no_hacks: + log_var[...] = 0 return mu, log_var @@ -160,7 +171,7 @@ model_p_X_given_z = ImageGivenLatentNet( optimizer = optim.Adam( itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()), - lr=4e-4, + lr=args.learning_rate, ) model_p_X_given_z.to(device) @@ -227,14 +238,20 @@ save_image(x, "input.png") mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +if args.no_hacks: + x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +else: + x = mean_p_X_given_z save_image(x, "output.png") # Generate a bunch of images z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +if args.no_hacks: + x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +else: + x = mean_p_X_given_z save_image(x, "synth.png") ###################################################################### -- 2.39.5