From dc1e3534151307491a1eacf053fc4aede631448b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 1 Mar 2024 22:34:54 +0100 Subject: [PATCH] Update. --- tiny_vae.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tiny_vae.py b/tiny_vae.py index bbdbf1a..577f717 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -26,7 +26,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") -parser.add_argument("--nb_epochs", type=int, default=25) +parser.add_argument("--nb_epochs", type=int, default=100) parser.add_argument("--batch_size", type=int, default=100) @@ -75,13 +75,13 @@ def log_p_gaussian(x, mu, log_var): ) -def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b): - mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1) - mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1) +def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b): + mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1) + mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1) var_a = log_var_a.exp() var_b = log_var_b.exp() return 0.5 * ( - log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b + log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b ).sum(1) @@ -176,27 +176,27 @@ test_input.sub_(train_mu).div_(train_std) ###################################################################### -mu_p_Z = train_input.new_zeros(1, args.latent_dim) -log_var_p_Z = mu_p_Z +mean_p_Z = train_input.new_zeros(1, args.latent_dim) +log_var_p_Z = mean_p_Z for epoch in range(args.nb_epochs): acc_loss = 0 for x in train_input.split(args.batch_size): - mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) - z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x) - mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) + mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) + z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) + mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) if args.no_dkl: - log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x) + log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x) log_p_x_z = log_p_gaussian( - x, mu_p_X_given_z, log_var_p_X_given_z - ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z) + x, mean_p_X_given_z, log_var_p_X_given_z + ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z) loss = -(log_p_x_z - log_q_z_given_x).mean() else: - log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z) + log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z) dkl_q_Z_given_x_from_p_Z = dkl_gaussians( - mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z + mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z ) loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean() @@ -224,17 +224,17 @@ save_image(x, "input.png") # Save the same images after encoding / decoding -mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) -z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x) -mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z) +mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) +z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) +mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) save_image(x, "output.png") # Generate a bunch of images -z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) -mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z) +z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) +mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) save_image(x, "synth.png") ###################################################################### -- 2.39.5