X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=tiny_vae.py;h=405c103ca8304098e5380e85508274cc9e3606df;hb=ae1c9180165d9264e9cfe152bb64164926b5ddd2;hp=d33cc4bbf2cec2833f66d8c85fa0ac5d3063b4e2;hpb=9529aef0fdc1d1810e3c331a01fdbe640669ecfc;p=pytorch.git diff --git a/tiny_vae.py b/tiny_vae.py index d33cc4b..405c103 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -28,9 +28,9 @@ parser = argparse.ArgumentParser( description="Very simple implementation of a VAE for teaching." ) -parser.add_argument("--nb_epochs", type=int, default=25) +parser.add_argument("--nb_epochs", type=int, default=100) -parser.add_argument("--learning_rate", type=float, default=1e-3) +parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--batch_size", type=int, default=100) @@ -44,11 +44,6 @@ parser.add_argument("--nb_channels", type=int, default=128) parser.add_argument("--no_dkl", action="store_true") -# With that option, do not follow the setup of the original VAE paper -# of forcing the variance of X|Z to 1 during training and to 0 for -# sampling, but optimize and use the variance. -parser.add_argument("--no_hacks", action="store_true") - args = parser.parse_args() log_file = open(args.log_filename, "w") @@ -70,12 +65,14 @@ def log_string(s): ###################################################################### -def sample_gaussian(mu, log_var): +def sample_gaussian(param): + mu, log_var = param std = log_var.mul(0.5).exp() return torch.randn(mu.size(), device=mu.device) * std + mu -def log_p_gaussian(x, mu, log_var): +def log_p_gaussian(x, param): + mu, log_var = param var = log_var.exp() return ( (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi)) @@ -84,9 +81,9 @@ def log_p_gaussian(x, mu, log_var): ) -def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b): - mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1) - mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1) +def dkl_gaussians(param_a, param_b): + mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1) + mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1) var_a = log_var_a.exp() var_b = log_var_b.exp() return 0.5 * ( @@ -144,8 +141,7 @@ class ImageGivenLatentNet(nn.Module): def forward(self, z): output = self.model(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] - if not args.no_hacks: - log_var[...] = 0 + # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1] return mu, log_var @@ -187,28 +183,26 @@ test_input.sub_(train_mu).div_(train_std) ###################################################################### -mean_p_Z = train_input.new_zeros(1, args.latent_dim) -log_var_p_Z = mean_p_Z +zeros = train_input.new_zeros(1, args.latent_dim) + +param_p_Z = zeros, zeros for epoch in range(args.nb_epochs): acc_loss = 0 for x in train_input.split(args.batch_size): - mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) - z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) - mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) + param_q_Z_given_x = model_q_Z_given_x(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model_p_X_given_z(z) + log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z) if args.no_dkl: - log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x) - log_p_x_z = log_p_gaussian( - x, mean_p_X_given_z, log_var_p_X_given_z - ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z) + log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x) + log_p_z = log_p_gaussian(z, param_p_Z) + log_p_x_z = log_p_x_given_z + log_p_x_z loss = -(log_p_x_z - log_q_z_given_x).mean() else: - log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z) - dkl_q_Z_given_x_from_p_Z = dkl_gaussians( - mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z - ) + dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z) loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean() optimizer.zero_grad() @@ -235,23 +229,21 @@ save_image(x, "input.png") # Save the same images after encoding / decoding -mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) -z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -if args.no_hacks: - x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) -else: - x = mean_p_X_given_z +param_q_Z_given_x = model_q_Z_given_x(x) +z = sample_gaussian(param_q_Z_given_x) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "output.png") +save_image(param_p_X_given_z[0], "output_mean.png") # Generate a bunch of images -z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -if args.no_hacks: - x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) -else: - x = mean_p_X_given_z +z = sample_gaussian( + (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1)) +) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "synth.png") +save_image(param_p_X_given_z[0], "synth_mean.png") ######################################################################