From 95972e8f683616155b2b9fd332312249918ab000 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 6 Sep 2024 13:09:55 +0200 Subject: [PATCH] Update. --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 4c30771..21c609c 100755 --- a/main.py +++ b/main.py @@ -819,7 +819,7 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0 - return result + return x_t_minus_1 ###################################################################### @@ -888,7 +888,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample() - hat_x_t_minus_1 = one_iteration_prediction * x_0 + ( + hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + ( 1 - one_iteration_prediction ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t) @@ -913,7 +913,7 @@ def model_ae_proba_solutions(model, input, log_proba=False): for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: mask_generate = quiz_machine.make_quiz_mask( - quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) logits = logits_hat_x_0_from_random_iteration( model, x_0, mask_generate, prompt_noise=args.prompt_noise @@ -939,7 +939,7 @@ def model_ae_argmax_nb_disagreements(model, input): nb_disagreements = 0 for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: mask_generate = quiz_machine.make_quiz_mask( - quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) logits = logits_hat_x_0_from_random_iteration( model, x_0, mask_generate, prompt_noise=args.prompt_noise @@ -966,7 +966,7 @@ def model_ae_argmax_predictions(model, input): for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)): for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: mask_generate = quiz_machine.make_quiz_mask( - quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) logits = logits_hat_x_0_from_random_iteration( model, x_0, mask_generate, prompt_noise=args.prompt_noise -- 2.39.5