From 4ddaff76d908e7f8444d2fb08e4c0afdabe32ea5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 3 Mar 2024 08:23:00 +0100 Subject: [PATCH] Update. --- tiny_vae.py | 52 +++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tiny_vae.py b/tiny_vae.py index 784f775..b81df9a 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -65,12 +65,14 @@ def log_string(s): ###################################################################### -def sample_gaussian(mu, log_var): +def sample_gaussian(param): + mu, log_var = param std = log_var.mul(0.5).exp() return torch.randn(mu.size(), device=mu.device) * std + mu -def log_p_gaussian(x, mu, log_var): +def log_p_gaussian(x, param): + mu, log_var = param var = log_var.exp() return ( (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi)) @@ -79,9 +81,9 @@ def log_p_gaussian(x, mu, log_var): ) -def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b): - mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1) - mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1) +def dkl_gaussians(param_a, param_b): + mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1) + mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1) var_a = log_var_a.exp() var_b = log_var_b.exp() return 0.5 * ( @@ -181,28 +183,26 @@ test_input.sub_(train_mu).div_(train_std) ###################################################################### -mean_p_Z = train_input.new_zeros(1, args.latent_dim) -log_var_p_Z = mean_p_Z +zeros = train_input.new_zeros(1, args.latent_dim) + +param_p_Z = zeros, zeros for epoch in range(args.nb_epochs): acc_loss = 0 for x in train_input.split(args.batch_size): - mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) - z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) - mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) + param_q_Z_given_x = model_q_Z_given_x(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model_p_X_given_z(z) + log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z) if args.no_dkl: - log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x) - log_p_x_z = log_p_gaussian( - x, mean_p_X_given_z, log_var_p_X_given_z - ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z) + log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x) + log_p_z = log_p_gaussian(z, param_p_Z) + log_p_x_z = log_p_x_given_z + log_p_x_z loss = -(log_p_x_z - log_q_z_given_x).mean() else: - log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z) - dkl_q_Z_given_x_from_p_Z = dkl_gaussians( - mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z - ) + dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z) loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean() optimizer.zero_grad() @@ -229,17 +229,19 @@ save_image(x, "input.png") # Save the same images after encoding / decoding -mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) -z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +param_q_Z_given_x = model_q_Z_given_x(x) +z = sample_gaussian(param_q_Z_given_x) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "output.png") # Generate a bunch of images -z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +z = sample_gaussian( + param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1) +) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "synth.png") ###################################################################### -- 2.39.5