Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 1 Mar 2024 23:36:04 +0000 (00:36 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 1 Mar 2024 23:36:04 +0000 (00:36 +0100)
tiny_vae.py

index 0895830..784f775 100755 (executable)
@@ -28,9 +28,9 @@ parser = argparse.ArgumentParser(
     description="Very simple implementation of a VAE for teaching."
 )
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=100)
 
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=2e-4)
 
 parser.add_argument("--batch_size", type=int, default=100)
 
@@ -44,12 +44,6 @@ parser.add_argument("--nb_channels", type=int, default=128)
 
 parser.add_argument("--no_dkl", action="store_true")
 
-# With that option, do not follow the setup of the original VAE paper
-# of forcing the variance of X|Z to 1 during training and to 0 for
-# sampling, but optimize and use the variance.
-
-parser.add_argument("--no_hacks", action="store_true")
-
 args = parser.parse_args()
 
 log_file = open(args.log_filename, "w")
@@ -145,8 +139,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
-        if not args.no_hacks:
-            log_var[...] = 0
+        # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
         return mu, log_var
 
 
@@ -239,20 +232,14 @@ save_image(x, "input.png")
 mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
 z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "output.png")
 
 # Generate a bunch of images
 
 z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
 mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
-    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
-    x = mean_p_X_given_z
+x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
 save_image(x, "synth.png")
 
 ######################################################################