From 8433ea2f5781449b1bc010bd08bd97acef7f7ae4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 13 Sep 2024 11:45:11 +0200 Subject: [PATCH] Update. --- main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 92a34f1..5c086cf 100755 --- a/main.py +++ b/main.py @@ -731,10 +731,12 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None single_iteration = deterministic(mask_generate)[:, None] - if mask_hints is not None: - mask_generate = mask_generate * (1 - mask_hints) + if mask_hints is None: + mask_start = mask_generate + else: + mask_start = mask_generate * (1 - mask_hints) - x_t = (1 - mask_generate) * x_0 + mask_generate * noise + x_t = (1 - mask_start) * x_0 + mask_start * noise changed = True -- 2.39.5