x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
- return result
+ return x_t_minus_1
######################################################################
hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
- hat_x_t_minus_1 = one_iteration_prediction * x_0 + (
+ hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
1 - one_iteration_prediction
) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
- quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
logits = logits_hat_x_0_from_random_iteration(
model, x_0, mask_generate, prompt_noise=args.prompt_noise
nb_disagreements = 0
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
- quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
logits = logits_hat_x_0_from_random_iteration(
model, x_0, mask_generate, prompt_noise=args.prompt_noise
for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
- quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
logits = logits_hat_x_0_from_random_iteration(
model, x_0, mask_generate, prompt_noise=args.prompt_noise