X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tiny_vae.py;h=784f77502e34ca6ab3a2a9d153b94bfe85461301;hb=ce8deb0b3038c17d3198a447a7a3c44a93606d38;hp=577f717e8c61d14cb98a5f9a99ae1af22e90cdbc;hpb=dc1e3534151307491a1eacf053fc4aede631448b;p=pytorch.git

diff --git a/tiny_vae.py b/tiny_vae.py
index 577f717..784f775 100755
--- a/tiny_vae.py
+++ b/tiny_vae.py
@@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+    description="Very simple implementation of a VAE for teaching."
+)
 
 parser.add_argument("--nb_epochs", type=int, default=100)
 
+parser.add_argument("--learning_rate", type=float, default=2e-4)
+
 parser.add_argument("--batch_size", type=int, default=100)
 
 parser.add_argument("--data_dir", type=str, default="./data/")
@@ -135,6 +139,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
+        # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
         return mu, log_var
 
 
@@ -160,7 +165,7 @@ model_p_X_given_z = ImageGivenLatentNet(
 
 optimizer = optim.Adam(
     itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
-    lr=4e-4,
+    lr=args.learning_rate,
 )
 
 model_p_X_given_z.to(device)