######################################################################
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+    description="Very simple implementation of a VAE for teaching."
+)
+
-parser.add_argument("--nb_epochs", type=int, default=100)
+parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--no_dkl", action="store_true")
+# With this option, do not follow the original VAE paper's setup of
+# forcing the variance of X|Z to 1 during training and to 0 for
+# sampling; instead, optimize the variance and sample with it (see the
+# illustrative loss sketch below).
+parser.add_argument("--no_hacks", action="store_true")
+
args = parser.parse_args()
log_file = open(args.log_filename, "w")
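# To make the hack concrete, here is a minimal sketch of the per-pixel
# reconstruction term -log p(x | z) (gaussian_nll and its imports are
# illustrative assumptions, not part of the diff). With log_var forced
# to 0, i.e. unit variance, the term reduces to 0.5 * (x - mu)**2 plus
# a constant, which is an ordinary mean-squared error.

import math
import torch

def gaussian_nll(mu, log_var, x):
    # -log N(x; mu, exp(log_var)), summed over all pixels
    return 0.5 * (log_var + math.log(2 * math.pi) + (x - mu).pow(2) / log_var.exp()).sum()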
def forward(self, z):
    output = self.model(z.view(z.size(0), -1, 1, 1))
    # The model emits two channels: the mean and the log-variance of X|Z
    mu, log_var = output[:, 0:1], output[:, 1:2]
+    if not args.no_hacks:
+        # Hack from the original paper: force the variance of X|Z to 1
+        # (log-variance 0) during training
+        log_var[...] = 0
    return mu, log_var
optimizer = optim.Adam(
    itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
-    lr=4e-4,
+    lr=args.learning_rate,
)
model_p_X_given_z.to(device)
mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+if args.no_hacks:
+    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+else:
+    # Hack: use the mean of X|Z as the sample (variance forced to 0)
+    x = mean_p_X_given_z
save_image(x, "output.png")
# Generate a bunch of images
z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+if args.no_hacks:
+    x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+else:
+    x = mean_p_X_given_z
save_image(x, "synth.png")
######################################################################