X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tiny_vae.py;h=405c103ca8304098e5380e85508274cc9e3606df;hb=ae1c9180165d9264e9cfe152bb64164926b5ddd2;hp=b81df9a79807da532d737227321758ede5272248;hpb=4ddaff76d908e7f8444d2fb08e4c0afdabe32ea5;p=pytorch.git diff --git a/tiny_vae.py b/tiny_vae.py index b81df9a..405c103 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -141,7 +141,7 @@ class ImageGivenLatentNet(nn.Module): def forward(self, z): output = self.model(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] - # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1] + # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1] return mu, log_var @@ -234,14 +234,16 @@ z = sample_gaussian(param_q_Z_given_x) param_p_X_given_z = model_p_X_given_z(z) x = sample_gaussian(param_p_X_given_z) save_image(x, "output.png") +save_image(param_p_X_given_z[0], "output_mean.png") # Generate a bunch of images z = sample_gaussian( - param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1) + (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1)) ) param_p_X_given_z = model_p_X_given_z(z) x = sample_gaussian(param_p_X_given_z) save_image(x, "synth.png") +save_image(param_p_X_given_z[0], "synth_mean.png") ######################################################################