Update.
[pytorch.git] / tiny_vae.py
index bbdbf1a..4d11c7f 100755 (executable)
@@ -11,7 +11,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import sys, os, argparse, time, math, itertools
+import sys, os, argparse, time, math
 
 import torch, torchvision
 
@@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+    description="Very simple implementation of a VAE for teaching."
+)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
+parser.add_argument("--learning_rate", type=float, default=1e-3)
+
 parser.add_argument("--batch_size", type=int, default=100)
 
 parser.add_argument("--data_dir", type=str, default="./data/")
@@ -36,10 +40,12 @@ parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--latent_dim", type=int, default=32)
 
-parser.add_argument("--nb_channels", type=int, default=128)
+parser.add_argument("--nb_channels", type=int, default=32)
 
 parser.add_argument("--no_dkl", action="store_true")
 
+parser.add_argument("--beta", type=float, default=1.0)
+
 args = parser.parse_args()
 
 log_file = open(args.log_filename, "w")
@@ -48,7 +54,7 @@ log_file = open(args.log_filename, "w")
 
 
 def log_string(s):
-    t = time.strftime("%Y-%m-%d_%H:%M:%S ", time.localtime())
+    t = time.strftime("%Y-%m-%d_%H:%M:%S ", time.localtime())
 
     if log_file is not None:
         log_file.write(t + s + "\n")
@@ -61,38 +67,53 @@ def log_string(s):
 ######################################################################
 
 
-def sample_gaussian(mu, log_var):
+def sample_categorical(param):
+    dist = torch.distributions.Categorical(logits=param)
+    return (dist.sample().unsqueeze(1).float() - train_mu) / train_std
+
+
+def log_p_categorical(x, param):
+    x = (x.squeeze(1) * train_std + train_mu).long().clamp(min=0, max=255)
+    param = param.permute(0, 3, 1, 2)
+    return -F.cross_entropy(param, x, reduction="none").flatten(1).sum(dim=1)
+
+
+def sample_gaussian(param):
+    mean, log_var = param
     std = log_var.mul(0.5).exp()
-    return torch.randn(mu.size(), device=mu.device) * std + mu
+    return torch.randn(mean.size(), device=mean.device) * std + mean
 
 
-def log_p_gaussian(x, mu, log_var):
+def log_p_gaussian(x, param):
+    mean, log_var, x = param[0].flatten(1), param[1].flatten(1), x.flatten(1)
     var = log_var.exp()
-    return (
-        (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
-        .flatten(1)
-        .sum(1)
-    )
+    return -0.5 * (((x - mean).pow(2) / var) + log_var + math.log(2 * math.pi)).sum(1)
 
 
-def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
-    mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
-    mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(param_a, param_b):
+    mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
+    mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
     var_a = log_var_a.exp()
     var_b = log_var_b.exp()
     return 0.5 * (
-        log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
+        log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
     ).sum(1)
 
 
+def dup_param(param, nb):
+    mean, log_var = param
+    s = (nb,) + (-1,) * (mean.dim() - 1)
+    return (mean.expand(s), log_var.expand(s))
+
+
 ######################################################################
 
 
-class LatentGivenImageNet(nn.Module):
+class VariationalAutoEncoder(nn.Module):
     def __init__(self, nb_channels, latent_dim):
         super().__init__()
 
-        self.model = nn.Sequential(
+        self.encoder = nn.Sequential(
             nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
             nn.ReLU(inplace=True),
             nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
@@ -106,17 +127,7 @@ class LatentGivenImageNet(nn.Module):
             nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
         )
 
-    def forward(self, x):
-        output = self.model(x).view(x.size(0), 2, -1)
-        mu, log_var = output[:, 0], output[:, 1]
-        return mu, log_var
-
-
-class ImageGivenLatentNet(nn.Module):
-    def __init__(self, nb_channels, latent_dim):
-        super().__init__()
-
-        self.model = nn.Sequential(
+        self.decoder = nn.Sequential(
             nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
             nn.ReLU(inplace=True),
             nn.ConvTranspose2d(
@@ -132,9 +143,18 @@ class ImageGivenLatentNet(nn.Module):
             nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
         )
 
-    def forward(self, z):
-        output = self.model(z.view(z.size(0), -1, 1, 1))
+    def encode(self, x):
+        output = self.encoder(x).view(x.size(0), 2, -1)
+        mu, log_var = output[:, 0], output[:, 1]
+        return mu, log_var
+
+    def decode(self, z):
+        # return self.decoder(z.view(z.size(0), -1, 1, 1)).permute(0, 2, 3, 1)
+        output = self.decoder(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
+        log_var.flatten(1)[...] = 1  # math.log(1e-1)
+        # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1]
+        # log_var = log_var.clamp(min=2*math.log(1/256))
         return mu, log_var
 
 
@@ -150,21 +170,56 @@ test_input = test_set.data.view(-1, 1, 28, 28).float()
 
 ######################################################################
 
-model_q_Z_given_x = LatentGivenImageNet(
-    nb_channels=args.nb_channels, latent_dim=args.latent_dim
-)
 
-model_p_X_given_z = ImageGivenLatentNet(
-    nb_channels=args.nb_channels, latent_dim=args.latent_dim
-)
+def save_images(model, prefix=""):
+    def save_image(x, filename):
+        x = x * train_std + train_mu
+        x = x.clamp(min=0, max=255) / 255
+        torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0)
+        log_string(f"wrote {filename}")
 
-optimizer = optim.Adam(
-    itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
-    lr=4e-4,
-)
+    # Save a bunch of train images
+
+    x = train_input[:36]
+    save_image(x, f"{prefix}train_input.png")
+
+    # Save the same images after encoding / decoding
 
-model_p_X_given_z.to(device)
-model_q_Z_given_x.to(device)
+    param_q_Z_given_x = model.encode(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model.decode(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}train_output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+    # Save a bunch of test images
+
+    x = test_input[:36]
+    save_image(x, f"{prefix}input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model.encode(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model.decode(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+    # Generate a bunch of images
+
+    z = sample_gaussian(dup_param(param_p_Z, x.size(0)))
+    param_p_X_given_z = model.decode(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}synth.png")
+    save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
+model = VariationalAutoEncoder(nb_channels=args.nb_channels, latent_dim=args.latent_dim)
+
+model.to(device)
 
 ######################################################################
 
@@ -176,29 +231,32 @@ test_input.sub_(train_mu).div_(train_std)
 
 ######################################################################
 
-mu_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mu_p_Z
+zeros = train_input.new_zeros(1, args.latent_dim)
+
+param_p_Z = zeros, zeros
+
+for n_epoch in range(args.nb_epochs):
+    optimizer = optim.Adam(
+        model.parameters(),
+        lr=args.learning_rate,
+    )
 
-for epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
-        mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-        z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-        mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+        param_q_Z_given_x = model.encode(x)
+        z = sample_gaussian(param_q_Z_given_x)
+        param_p_X_given_z = model.decode(z)
+        log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)
 
         if args.no_dkl:
-            log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x)
-            log_p_x_z = log_p_gaussian(
-                x, mu_p_X_given_z, log_var_p_X_given_z
-            ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z)
+            log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
+            log_p_z = log_p_gaussian(z, param_p_Z)
+            log_p_x_z = log_p_x_given_z + log_p_z
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
-            log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z)
-            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
-                mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z
-            )
-            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
+            loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()
 
         optimizer.zero_grad()
         loss.backward()
@@ -206,35 +264,9 @@ for epoch in range(args.nb_epochs):
 
         acc_loss += loss.item() * x.size(0)
 
-    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
-
-######################################################################
-
-
-def save_image(x, filename):
-    x = x * train_std + train_mu
-    x = x.clamp(min=0, max=255) / 255
-    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
-
-
-# Save a bunch of test images
-
-x = test_input[:256]
-save_image(x, "input.png")
-
-# Save the same images after encoding / decoding
-
-mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
-save_image(x, "output.png")
-
-# Generate a bunch of images
+    log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")
 
-z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
-save_image(x, "synth.png")
+    if (n_epoch + 1) % 25 == 0:
+        save_images(model, f"epoch_{n_epoch+1:04d}_")
 
 ######################################################################