Update.
[pytorch.git] / tiny_vae.py
index 577f717..784f775 100755 (executable)
@@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+    description="Very simple implementation of a VAE for teaching."
+)
 
 parser.add_argument("--nb_epochs", type=int, default=100)
 
+parser.add_argument("--learning_rate", type=float, default=2e-4)
+
 parser.add_argument("--batch_size", type=int, default=100)
 
 parser.add_argument("--data_dir", type=str, default="./data/")
@@ -135,6 +139,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
+        # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
         return mu, log_var
 
 
@@ -160,7 +165,7 @@ model_p_X_given_z = ImageGivenLatentNet(
 
 optimizer = optim.Adam(
     itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
-    lr=4e-4,
+    lr=args.learning_rate,
 )
 
 model_p_X_given_z.to(device)