######################################################################
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+ description="Very simple implementation of a VAE for teaching."
+)
parser.add_argument("--nb_epochs", type=int, default=100)
+parser.add_argument("--learning_rate", type=float, default=2e-4)
+
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--data_dir", type=str, default="./data/")
def forward(self, z):
output = self.model(z.view(z.size(0), -1, 1, 1))
mu, log_var = output[:, 0:1], output[:, 1:2]
+ # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
return mu, log_var
optimizer = optim.Adam(
itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
- lr=4e-4,
+ lr=args.learning_rate,
)
model_p_X_given_z.to(device)