X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tiny_vae.py;h=10ce19fe9a691ac779f3507e12e249e56e4fc33e;hb=c9b85b67c983916d809e185ef12dec8b76e0aa68;hp=0895830a13d123a511c12c6b37e6ff414936f39e;hpb=39ce2c3b3e9f4a0d4da6cb7eb84d15cb03c55ff4;p=pytorch.git

diff --git a/tiny_vae.py b/tiny_vae.py
index 0895830..10ce19f 100755
--- a/tiny_vae.py
+++ b/tiny_vae.py
@@ -28,9 +28,9 @@ parser = argparse.ArgumentParser(
     description="Very simple implementation of a VAE for teaching."
 )
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=100)
 
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=1e-4)
 
 parser.add_argument("--batch_size", type=int, default=100)
 
@@ -40,15 +40,11 @@ parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--latent_dim", type=int, default=32)
 
-parser.add_argument("--nb_channels", type=int, default=128)
+parser.add_argument("--nb_channels", type=int, default=64)
 
 parser.add_argument("--no_dkl", action="store_true")
 
-# With that option, do not follow the setup of the original VAE paper
-# of forcing the variance of X|Z to 1 during training and to 0 for
-# sampling, but optimize and use the variance.
-
-parser.add_argument("--no_hacks", action="store_true")
+parser.add_argument("--beta", type=float, default=1.0)
 
 args = parser.parse_args()
 
@@ -71,12 +67,14 @@ def log_string(s):
 ######################################################################
 
 
-def sample_gaussian(mu, log_var):
+def sample_gaussian(param):
+    mu, log_var = param
     std = log_var.mul(0.5).exp()
     return torch.randn(mu.size(), device=mu.device) * std + mu
 
 
-def log_p_gaussian(x, mu, log_var):
+def log_p_gaussian(x, param):
+    mu, log_var = param
     var = log_var.exp()
     return (
         (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
@@ -85,9 +83,9 @@ def log_p_gaussian(x, mu, log_var):
     )
 
 
-def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
-    mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
-    mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(param_a, param_b):
+    mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
+    mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
     var_a = log_var_a.exp()
     var_b = log_var_b.exp()
     return 0.5 * (
@@ -145,8 +143,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
-        if not args.no_hacks:
-            log_var[...] = 0
+        # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1]
         return mu, log_var
 
@@ -162,6 +159,55 @@ test_input = test_set.data.view(-1, 1, 28, 28).float()
 
 ######################################################################
 
+
+def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
+    def save_image(x, filename):
+        x = x * train_std + train_mu
+        x = x.clamp(min=0, max=255) / 255
+        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+        log_string(f"wrote {filename}")
+
+    # Save a bunch of train images
+
+    x = train_input[:256]
+    save_image(x, f"{prefix}train_input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}train_output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+    # Save a bunch of test images
+
+    x = test_input[:256]
+    save_image(x, f"{prefix}input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+    # Generate a bunch of images
+
+    z = sample_gaussian(
+        (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
+    )
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}synth.png")
+    save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
 model_q_Z_given_x = LatentGivenImageNet(
     nb_channels=args.nb_channels, latent_dim=args.latent_dim
 )
@@ -188,29 +234,27 @@ test_input.sub_(train_mu).div_(train_std)
 
 ######################################################################
 
-mean_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mean_p_Z
+zeros = train_input.new_zeros(1, args.latent_dim)
+
+param_p_Z = zeros, zeros
 
-for epoch in range(args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
-        mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-        z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
-        mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+        param_q_Z_given_x = model_q_Z_given_x(x)
+        z = sample_gaussian(param_q_Z_given_x)
+        param_p_X_given_z = model_p_X_given_z(z)
+        log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)
 
         if args.no_dkl:
-            log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
-            log_p_x_z = log_p_gaussian(
-                x, mean_p_X_given_z, log_var_p_X_given_z
-            ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
+            log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
+            log_p_z = log_p_gaussian(z, param_p_Z)
+            log_p_x_z = log_p_x_given_z + log_p_z
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
-            log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
-            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
-                mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
-            )
-            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
+            loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()
 
         optimizer.zero_grad()
         loss.backward()
@@ -218,41 +262,9 @@ for epoch in range(args.nb_epochs):
 
         acc_loss += loss.item() * x.size(0)
 
-    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
-
-######################################################################
-
-
-def save_image(x, filename):
-    x = x * train_std + train_mu
-    x = x.clamp(min=0, max=255) / 255
-    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
-
-
-# Save a bunch of test images
-
-x = test_input[:256]
-save_image(x, "input.png")
-
-# Save the same images after encoding / decoding
-
-mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
-save_image(x, "output.png")
-
-# Generate a bunch of images
+    log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")
 
-z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
-save_image(x, "synth.png")
+    if (n_epoch + 1) % 25 == 0:
+        save_images(model_q_Z_given_x, model_p_X_given_z, f"epoch_{n_epoch+1:04d}_")
 
 ######################################################################
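
Note (not part of the patch): after this refactor both training branches target the same ELBO. The default branch uses the closed-form KL computed by dkl_gaussians, whose return expression is cut off at the hunk boundary above; the --no_dkl branch estimates that same KL term by sampling, via loss = -(log_p_x_z - log_q_z_given_x).mean(). A minimal, self-contained Python sketch of the equivalence follows; the helper bodies are copied from the patched tiny_vae.py, the truncated closed form is reconstructed here from the standard diagonal-Gaussian formula (an assumption, not taken from the diff), and the batch size and latent dimension are arbitrary.

import math

import torch


def sample_gaussian(param):
    # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)
    mu, log_var = param
    std = log_var.mul(0.5).exp()
    return torch.randn(mu.size(), device=mu.device) * std + mu


def log_p_gaussian(x, param):
    # Diagonal Gaussian log-density, summed over all non-batch dimensions
    mu, log_var = param
    var = log_var.exp()
    return (
        (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
        .flatten(1)
        .sum(1)
    )


def dkl_gaussians(param_a, param_b):
    # Closed-form KL(N_a || N_b) for diagonal Gaussians; the return
    # expression is reconstructed from the textbook formula since the
    # hunk above truncates it
    mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
    mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
    var_a = log_var_a.exp()
    var_b = log_var_b.exp()
    return 0.5 * (
        var_a / var_b + (mean_b - mean_a).pow(2) / var_b + log_var_b - log_var_a - 1
    ).sum(1)


# Two arbitrary diagonal Gaussians over a 32-dim latent, replicated over
# a large batch so a single pass gives a usable Monte-Carlo estimate
nb, dim = 10000, 32
param_a = torch.randn(1, dim).expand(nb, -1), torch.randn(1, dim).expand(nb, -1)
param_b = torch.randn(1, dim).expand(nb, -1), torch.randn(1, dim).expand(nb, -1)

# Closed form, as in the default training branch
dkl_closed = dkl_gaussians(param_a, param_b).mean()

# Sampled estimate E_{z~a}[log a(z) - log b(z)], as in the --no_dkl branch
z = sample_gaussian(param_a)
dkl_mc = (log_p_gaussian(z, param_a) - log_p_gaussian(z, param_b)).mean()

print(f"closed form {dkl_closed.item():.4f} / monte-carlo {dkl_mc.item():.4f}")

With the new --beta flag at its default of 1.0 the loss is the standard negative ELBO; larger values weight the KL term more heavily, beta-VAE style.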