+
+def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
+ def save_image(x, filename):
+ x = x * train_std + train_mu
+ x = x.clamp(min=0, max=255) / 255
+ torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+ log_string(f"wrote {filename}")
+
+ # Save a bunch of train images
+
+ x = train_input[:256]
+ save_image(x, f"{prefix}train_input.png")
+
+ # Save the same images after encoding / decoding
+
+ param_q_Z_given_x = model_q_Z_given_x(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}train_output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+ # Save a bunch of test images
+
+ x = test_input[:256]
+ save_image(x, f"{prefix}input.png")
+
+ # Save the same images after encoding / decoding
+
+ param_q_Z_given_x = model_q_Z_given_x(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+ # Generate a bunch of images
+
+ z = sample_gaussian(
+ (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
+ )
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}synth.png")
+ save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+