X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tiny_vae.py;h=4d11c7f41c80d42903893cb2b756bb9392c8e79c;hp=bbdbf1a734cb8379cb746aa217d50cb8435f977c;hb=HEAD;hpb=c951f0b1b425dc91ba74e9cb75425b0ad2f481ac diff --git a/tiny_vae.py b/tiny_vae.py index bbdbf1a..4d11c7f 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -11,7 +11,7 @@ # Written by Francois Fleuret -import sys, os, argparse, time, math, itertools +import sys, os, argparse, time, math import torch, torchvision @@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") +parser = argparse.ArgumentParser( + description="Very simple implementation of a VAE for teaching." +) parser.add_argument("--nb_epochs", type=int, default=25) +parser.add_argument("--learning_rate", type=float, default=1e-3) + parser.add_argument("--batch_size", type=int, default=100) parser.add_argument("--data_dir", type=str, default="./data/") @@ -36,10 +40,12 @@ parser.add_argument("--log_filename", type=str, default="train.log") parser.add_argument("--latent_dim", type=int, default=32) -parser.add_argument("--nb_channels", type=int, default=128) +parser.add_argument("--nb_channels", type=int, default=32) parser.add_argument("--no_dkl", action="store_true") +parser.add_argument("--beta", type=float, default=1.0) + args = parser.parse_args() log_file = open(args.log_filename, "w") @@ -48,7 +54,7 @@ log_file = open(args.log_filename, "w") def log_string(s): - t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + t = time.strftime("%Y-%m-%d_%H:%M:%S ", time.localtime()) if log_file is not None: log_file.write(t + s + "\n") @@ -61,38 +67,53 @@ def log_string(s): ###################################################################### -def sample_gaussian(mu, log_var): +def sample_categorical(param): + dist = torch.distributions.Categorical(logits=param) + return (dist.sample().unsqueeze(1).float() - train_mu) / train_std + + +def log_p_categorical(x, param): + x = (x.squeeze(1) * train_std + train_mu).long().clamp(min=0, max=255) + param = param.permute(0, 3, 1, 2) + return -F.cross_entropy(param, x, reduction="none").flatten(1).sum(dim=1) + + +def sample_gaussian(param): + mean, log_var = param std = log_var.mul(0.5).exp() - return torch.randn(mu.size(), device=mu.device) * std + mu + return torch.randn(mean.size(), device=mean.device) * std + mean -def log_p_gaussian(x, mu, log_var): +def log_p_gaussian(x, param): + mean, log_var, x = param[0].flatten(1), param[1].flatten(1), x.flatten(1) var = log_var.exp() - return ( - (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi)) - .flatten(1) - .sum(1) - ) + return -0.5 * (((x - mean).pow(2) / var) + log_var + math.log(2 * math.pi)).sum(1) -def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b): - mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1) - mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1) +def dkl_gaussians(param_a, param_b): + mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1) + mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1) var_a = log_var_a.exp() var_b = log_var_b.exp() return 0.5 * ( - log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b + log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b ).sum(1) +def dup_param(param, nb): + mean, log_var = param + s = (nb,) + (-1,) * (mean.dim() - 1) + return (mean.expand(s), log_var.expand(s)) + + ###################################################################### -class LatentGivenImageNet(nn.Module): +class VariationalAutoEncoder(nn.Module): def __init__(self, nb_channels, latent_dim): super().__init__() - self.model = nn.Sequential( + self.encoder = nn.Sequential( nn.Conv2d(1, nb_channels, kernel_size=1), # to 28x28 nn.ReLU(inplace=True), nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 24x24 @@ -106,17 +127,7 @@ class LatentGivenImageNet(nn.Module): nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4), ) - def forward(self, x): - output = self.model(x).view(x.size(0), 2, -1) - mu, log_var = output[:, 0], output[:, 1] - return mu, log_var - - -class ImageGivenLatentNet(nn.Module): - def __init__(self, nb_channels, latent_dim): - super().__init__() - - self.model = nn.Sequential( + self.decoder = nn.Sequential( nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4), nn.ReLU(inplace=True), nn.ConvTranspose2d( @@ -132,9 +143,18 @@ class ImageGivenLatentNet(nn.Module): nn.ConvTranspose2d(nb_channels, 2, kernel_size=5), # from 24x24 ) - def forward(self, z): - output = self.model(z.view(z.size(0), -1, 1, 1)) + def encode(self, x): + output = self.encoder(x).view(x.size(0), 2, -1) + mu, log_var = output[:, 0], output[:, 1] + return mu, log_var + + def decode(self, z): + # return self.decoder(z.view(z.size(0), -1, 1, 1)).permute(0, 2, 3, 1) + output = self.decoder(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] + log_var.flatten(1)[...] = 1 # math.log(1e-1) + # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1] + # log_var = log_var.clamp(min=2*math.log(1/256)) return mu, log_var @@ -150,21 +170,56 @@ test_input = test_set.data.view(-1, 1, 28, 28).float() ###################################################################### -model_q_Z_given_x = LatentGivenImageNet( - nb_channels=args.nb_channels, latent_dim=args.latent_dim -) -model_p_X_given_z = ImageGivenLatentNet( - nb_channels=args.nb_channels, latent_dim=args.latent_dim -) +def save_images(model, prefix=""): + def save_image(x, filename): + x = x * train_std + train_mu + x = x.clamp(min=0, max=255) / 255 + torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0) + log_string(f"wrote {filename}") -optimizer = optim.Adam( - itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()), - lr=4e-4, -) + # Save a bunch of train images + + x = train_input[:36] + save_image(x, f"{prefix}train_input.png") + + # Save the same images after encoding / decoding -model_p_X_given_z.to(device) -model_q_Z_given_x.to(device) + param_q_Z_given_x = model.encode(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model.decode(z) + x = sample_gaussian(param_p_X_given_z) + save_image(x, f"{prefix}train_output.png") + save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png") + + # Save a bunch of test images + + x = test_input[:36] + save_image(x, f"{prefix}input.png") + + # Save the same images after encoding / decoding + + param_q_Z_given_x = model.encode(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model.decode(z) + x = sample_gaussian(param_p_X_given_z) + save_image(x, f"{prefix}output.png") + save_image(param_p_X_given_z[0], f"{prefix}output_mean.png") + + # Generate a bunch of images + + z = sample_gaussian(dup_param(param_p_Z, x.size(0))) + param_p_X_given_z = model.decode(z) + x = sample_gaussian(param_p_X_given_z) + save_image(x, f"{prefix}synth.png") + save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png") + + +###################################################################### + +model = VariationalAutoEncoder(nb_channels=args.nb_channels, latent_dim=args.latent_dim) + +model.to(device) ###################################################################### @@ -176,29 +231,32 @@ test_input.sub_(train_mu).div_(train_std) ###################################################################### -mu_p_Z = train_input.new_zeros(1, args.latent_dim) -log_var_p_Z = mu_p_Z +zeros = train_input.new_zeros(1, args.latent_dim) + +param_p_Z = zeros, zeros + +for n_epoch in range(args.nb_epochs): + optimizer = optim.Adam( + model.parameters(), + lr=args.learning_rate, + ) -for epoch in range(args.nb_epochs): acc_loss = 0 for x in train_input.split(args.batch_size): - mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) - z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x) - mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) + param_q_Z_given_x = model.encode(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model.decode(z) + log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z) if args.no_dkl: - log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x) - log_p_x_z = log_p_gaussian( - x, mu_p_X_given_z, log_var_p_X_given_z - ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z) + log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x) + log_p_z = log_p_gaussian(z, param_p_Z) + log_p_x_z = log_p_x_given_z + log_p_z loss = -(log_p_x_z - log_q_z_given_x).mean() else: - log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z) - dkl_q_Z_given_x_from_p_Z = dkl_gaussians( - mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z - ) - loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean() + dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z) + loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean() optimizer.zero_grad() loss.backward() @@ -206,35 +264,9 @@ for epoch in range(args.nb_epochs): acc_loss += loss.item() * x.size(0) - log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}") - -###################################################################### - - -def save_image(x, filename): - x = x * train_std + train_mu - x = x.clamp(min=0, max=255) / 255 - torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8) - - -# Save a bunch of test images - -x = test_input[:256] -save_image(x, "input.png") - -# Save the same images after encoding / decoding - -mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) -z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x) -mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z) -save_image(x, "output.png") - -# Generate a bunch of images + log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}") -z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) -mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z) -save_image(x, "synth.png") + if (n_epoch + 1) % 25 == 0: + save_images(model, f"epoch_{n_epoch+1:04d}_") ######################################################################