-def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
- mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
- mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
+ mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
+ mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)