Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 3 Mar 2024 11:17:56 +0000 (12:17 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 3 Mar 2024 11:17:56 +0000 (12:17 +0100)
tiny_vae.py

index b81df9a..cba42e1 100755 (executable)
@@ -141,7 +141,7 @@ class ImageGivenLatentNet(nn.Module):
     def forward(self, z):
         output = self.model(z.view(z.size(0), -1, 1, 1))
         mu, log_var = output[:, 0:1], output[:, 1:2]
-        # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
+        log_var.flatten(1)[...] = log_var.flatten(1)[:, :1]
         return mu, log_var
 
 
@@ -234,14 +234,16 @@ z = sample_gaussian(param_q_Z_given_x)
 param_p_X_given_z = model_p_X_given_z(z)
 x = sample_gaussian(param_p_X_given_z)
 save_image(x, "output.png")
+save_image(param_p_X_given_z[0], "output_mean.png")
 
 # Generate a bunch of images
 
 z = sample_gaussian(
-    param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1)
+    (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
 )
 param_p_X_given_z = model_p_X_given_z(z)
 x = sample_gaussian(param_p_X_given_z)
 save_image(x, "synth.png")
+save_image(param_p_X_given_z[0], "synth_mean.png")
 
 ######################################################################