+ param_q_Z_given_x = model.encode(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model.decode(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}train_output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+ # Save a bunch of test images
+
+ x = test_input[:36]
+ save_image(x, f"{prefix}input.png")
+
+ # Save the same images after encoding / decoding
+
+ param_q_Z_given_x = model.encode(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model.decode(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+ # Generate a bunch of images
+
+ z = sample_gaussian(dup_param(param_p_Z, x.size(0)))
+ param_p_X_given_z = model.decode(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}synth.png")
+ save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
+model = VariationalAutoEncoder(nb_channels=args.nb_channels, latent_dim=args.latent_dim)
+
+model.to(device)