Update.
[pytorch.git] / tiny_vae.py
index bbdbf1a..d33cc4b 100755 (executable)
@@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+    description="Very simple implementation of a VAE for teaching."
+)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
+parser.add_argument("--learning_rate", type=float, default=1e-3)
+
 parser.add_argument("--batch_size", type=int, default=100)
 
 parser.add_argument("--data_dir", type=str, default="./data/")
@@ -40,6 +44,11 @@ parser.add_argument("--nb_channels", type=int, default=128)
 
 parser.add_argument("--no_dkl", action="store_true")
 
+# With that option, do not follow the setup of the original VAE paper
+# of forcing the variance of X|Z to 1 during training and to 0 for
+# sampling, but optimize and use the variance.
+parser.add_argument("--no_hacks", action="store_true")
+
 args = parser.parse_args()
 
 log_file = open(args.log_filename, "w")
@@ -75,13 +84,13 @@ def log_p_gaussian(x, mu, log_var):
     )
 
 
-def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
-    mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
-    mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
+    mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
+    mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
     var_a = log_var_a.exp()
     var_b = log_var_b.exp()
     return 0.5 * (
-        log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
+        log_var_b - log_var_a - 1 + (mean_a - mean_b).pow(2) / var_b + var_a / var_b
     ).sum(1)
 
 
@@ -135,6 +144,8 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
+        if not args.no_hacks:
+            log_var[...] = 0
         return mu, log_var
 
 
@@ -160,7 +171,7 @@ model_p_X_given_z = ImageGivenLatentNet(
 
 optimizer = optim.Adam(
     itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
-    lr=4e-4,
+    lr=args.learning_rate,
 )
 
 model_p_X_given_z.to(device)
@@ -176,27 +187,27 @@ test_input.sub_(train_mu).div_(train_std)
 
 ######################################################################
 
-mu_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mu_p_Z
+mean_p_Z = train_input.new_zeros(1, args.latent_dim)
+log_var_p_Z = mean_p_Z
 
 for epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
-        mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-        z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-        mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+        mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+        z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
+        mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
 
         if args.no_dkl:
-            log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x)
+            log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
             log_p_x_z = log_p_gaussian(
-                x, mu_p_X_given_z, log_var_p_X_given_z
-            ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z)
+                x, mean_p_X_given_z, log_var_p_X_given_z
+            ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
-            log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z)
+            log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
-                mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z
+                mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
             )
             loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
 
@@ -224,17 +235,23 @@ save_image(x, "input.png")
 
 # Save the same images after encoding / decoding
 
-mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
+mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+if args.no_hacks:
+    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+else:
+    x = mean_p_X_given_z
 save_image(x, "output.png")
 
 # Generate a bunch of images
 
-z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
+mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+if args.no_hacks:
+    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+else:
+    x = mean_p_X_given_z
 save_image(x, "synth.png")
 
 ######################################################################