From c9b85b67c983916d809e185ef12dec8b76e0aa68 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 4 Mar 2024 07:20:22 +0100
Subject: [PATCH] Update.

---
Review notes (commentary only; text between the three-dash line and the
diffstat is not part of the commit message or of the diff):

* Defaults changed: --learning_rate 2e-4 -> 1e-4, --nb_channels 128 -> 64.
* New option --beta (default 1.0) scales the KL term of the loss. It is only
  used in the branch that calls dkl_gaussians; the other branch of the
  if/else is not touched by this patch.
* The image-dumping statements that ran once, at top level after the
  training loop, are moved into
  save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""). The function
  additionally writes train_input / train_output / train_output_mean images
  and logs "wrote <filename>" for every file.
* save_images reads train_input, test_input, train_std, train_mu and
  param_p_Z as module-level names, not parameters. param_p_Z is assigned
  just before the training loop, i.e. before the first call.
* Behaviour change: images are now written only when
  (n_epoch + 1) % 25 == 0. A run whose --nb_epochs is not a multiple of 25
  gets no images for its final epoch, and a run with fewer than 25 epochs
  writes none at all.
* NOTE(review): no torch.no_grad() is visible around the forward passes in
  save_images; confirm whether gradient tracking is disabled elsewhere.
* NOTE(review): the From: header carries no e-mail address in this copy.

 tiny_vae.py | 95 ++++++++++++++++++++++++++++++++---------------------
 1 file changed, 58 insertions(+), 37 deletions(-)

diff --git a/tiny_vae.py b/tiny_vae.py
index 405c103..10ce19f 100755
--- a/tiny_vae.py
+++ b/tiny_vae.py
@@ -30,7 +30,7 @@ parser = argparse.ArgumentParser(
 
 parser.add_argument("--nb_epochs", type=int, default=100)
 
-parser.add_argument("--learning_rate", type=float, default=2e-4)
+parser.add_argument("--learning_rate", type=float, default=1e-4)
 
 parser.add_argument("--batch_size", type=int, default=100)
 
@@ -40,10 +40,12 @@ parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--latent_dim", type=int, default=32)
 
-parser.add_argument("--nb_channels", type=int, default=128)
+parser.add_argument("--nb_channels", type=int, default=64)
 
 parser.add_argument("--no_dkl", action="store_true")
 
+parser.add_argument("--beta", type=float, default=1.0)
+
 args = parser.parse_args()
 
 log_file = open(args.log_filename, "w")
@@ -157,6 +159,55 @@ test_input = test_set.data.view(-1, 1, 28, 28).float()
 
 ######################################################################
 
+
+def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
+    def save_image(x, filename):
+        x = x * train_std + train_mu
+        x = x.clamp(min=0, max=255) / 255
+        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+        log_string(f"wrote {filename}")
+
+    # Save a bunch of train images
+
+    x = train_input[:256]
+    save_image(x, f"{prefix}train_input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}train_output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+    # Save a bunch of test images
+
+    x = test_input[:256]
+    save_image(x, f"{prefix}input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+    # Generate a bunch of images
+
+    z = sample_gaussian(
+        (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
+    )
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}synth.png")
+    save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
 model_q_Z_given_x = LatentGivenImageNet(
     nb_channels=args.nb_channels, latent_dim=args.latent_dim
 )
@@ -187,7 +238,7 @@ zeros = train_input.new_zeros(1, args.latent_dim)
 
 param_p_Z = zeros, zeros
 
-for epoch in range(args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
@@ -203,7 +254,7 @@ for epoch in range(args.nb_epochs):
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
-            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+            loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()
 
         optimizer.zero_grad()
         loss.backward()
@@ -211,39 +262,9 @@ for epoch in range(args.nb_epochs):
 
         acc_loss += loss.item() * x.size(0)
 
-    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
-
-######################################################################
-
-
-def save_image(x, filename):
-    x = x * train_std + train_mu
-    x = x.clamp(min=0, max=255) / 255
-    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+    log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")
 
-
-# Save a bunch of test images
-
-x = test_input[:256]
-save_image(x, "input.png")
-
-# Save the same images after encoding / decoding
-
-param_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(param_q_Z_given_x)
-param_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(param_p_X_given_z)
-save_image(x, "output.png")
-save_image(param_p_X_given_z[0], "output_mean.png")
-
-# Generate a bunch of images
-
-z = sample_gaussian(
-    (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
-)
-param_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(param_p_X_given_z)
-save_image(x, "synth.png")
-save_image(param_p_X_given_z[0], "synth_mean.png")
+    if (n_epoch + 1) % 25 == 0:
+        save_images(model_q_Z_given_x, model_p_X_given_z, f"epoch_{n_epoch+1:04d}_")
 
 ######################################################################
-- 
2.39.5