From 80b028965bdbccf7b9d586af5f5f2d3ec41e9432 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 26 Aug 2024 20:11:50 +0200 Subject: [PATCH] Update. --- main.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 3c8b4b9..eb0f776 100755 --- a/main.py +++ b/main.py @@ -892,6 +892,15 @@ def deterministic(mask_generate): return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long() +def prioritized_rand(low): + x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values + k = torch.rand(low.size(), device=low.device) + low.long() + k = k.sort(dim=1).indices + y = x.new(x.size()) + y.scatter_(dim=1, index=k, src=x) + return y + + def ae_generate( model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50 ): @@ -909,20 +918,23 @@ def ae_generate( input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) dist = torch.distributions.categorical.Categorical(logits=logits) + final = dist.sample() + + r = prioritized_rand(final != input) - r = torch.rand(mask_generate.size(), device=mask_generate.device) mask_erased = mask_generate * (r <= proba_erased).long() mask_to_change = d * mask_generate + (1 - d) * mask_erased - update = (1 - mask_to_change) * input + mask_to_change * dist.sample() + update = (1 - mask_to_change) * input + mask_to_change * final if update.equal(input): - log_string(f"converged at iteration {it}") break else: changed = changed & (update != input).max(dim=1).values input[changed] = update[changed] + log_string(f"remains {changed.long().sum()}") + return input -- 2.39.5