description="Very simple implementation of a VAE for teaching."
)
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=100)
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--latent_dim", type=int, default=32)
-parser.add_argument("--nb_channels", type=int, default=128)
+parser.add_argument("--nb_channels", type=int, default=64)
parser.add_argument("--no_dkl", action="store_true")
-# With that option, do not follow the setup of the original VAE paper
-# of forcing the variance of X|Z to 1 during training and to 0 for
-# sampling, but optimize and use the variance.
-
-parser.add_argument("--no_hacks", action="store_true")
+parser.add_argument("--beta", type=float, default=1.0)
args = parser.parse_args()
######################################################################
-def sample_gaussian(mu, log_var):
+def sample_gaussian(param):
+ mu, log_var = param
std = log_var.mul(0.5).exp()
return torch.randn(mu.size(), device=mu.device) * std + mu
-def log_p_gaussian(x, mu, log_var):
+def log_p_gaussian(x, param):
+ mu, log_var = param
var = log_var.exp()
return (
(-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
)
-def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
- mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
- mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(param_a, param_b):
+ mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
+ mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
var_a = log_var_a.exp()
var_b = log_var_b.exp()
return 0.5 * (
def forward(self, z):
output = self.model(z.view(z.size(0), -1, 1, 1))
mu, log_var = output[:, 0:1], output[:, 1:2]
- if not args.no_hacks:
- log_var[...] = 0
+ # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1]
return mu, log_var
######################################################################
+
+def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
+ def save_image(x, filename):
+ x = x * train_std + train_mu
+ x = x.clamp(min=0, max=255) / 255
+ torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+ log_string(f"wrote {filename}")
+
+ # Save a bunch of train images
+
+ x = train_input[:256]
+ save_image(x, f"{prefix}train_input.png")
+
+ # Save the same images after encoding / decoding
+
+ param_q_Z_given_x = model_q_Z_given_x(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}train_output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+ # Save a bunch of test images
+
+ x = test_input[:256]
+ save_image(x, f"{prefix}input.png")
+
+ # Save the same images after encoding / decoding
+
+ param_q_Z_given_x = model_q_Z_given_x(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}output.png")
+ save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+ # Generate a bunch of images
+
+ z = sample_gaussian(
+ (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
+ )
+ param_p_X_given_z = model_p_X_given_z(z)
+ x = sample_gaussian(param_p_X_given_z)
+ save_image(x, f"{prefix}synth.png")
+ save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
model_q_Z_given_x = LatentGivenImageNet(
nb_channels=args.nb_channels, latent_dim=args.latent_dim
)
######################################################################
-mean_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mean_p_Z
+zeros = train_input.new_zeros(1, args.latent_dim)
+
+param_p_Z = zeros, zeros
-for epoch in range(args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
acc_loss = 0
for x in train_input.split(args.batch_size):
- mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
- z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
- mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+ param_q_Z_given_x = model_q_Z_given_x(x)
+ z = sample_gaussian(param_q_Z_given_x)
+ param_p_X_given_z = model_p_X_given_z(z)
+ log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)
if args.no_dkl:
- log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
- log_p_x_z = log_p_gaussian(
- x, mean_p_X_given_z, log_var_p_X_given_z
- ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
+ log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
+ log_p_z = log_p_gaussian(z, param_p_Z)
+ log_p_x_z = log_p_x_given_z + log_p_x_z
loss = -(log_p_x_z - log_q_z_given_x).mean()
else:
- log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
- dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
- mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
- )
- loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+ dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
+ loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()
optimizer.zero_grad()
loss.backward()
acc_loss += loss.item() * x.size(0)
- log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
-
-######################################################################
-
-
-def save_image(x, filename):
- x = x * train_std + train_mu
- x = x.clamp(min=0, max=255) / 255
- torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
-
-
-# Save a bunch of test images
-
-x = test_input[:256]
-save_image(x, "input.png")
-
-# Save the same images after encoding / decoding
-
-mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
- x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
- x = mean_p_X_given_z
-save_image(x, "output.png")
-
-# Generate a bunch of images
+ log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")
-z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-if args.no_hacks:
- x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
-else:
- x = mean_p_X_given_z
-save_image(x, "synth.png")
+ if (n_epoch + 1) % 25 == 0:
+ save_images(model_q_Z_given_x, model_p_X_given_z, f"epoch_{n_epoch+1:04d}_")
######################################################################