    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
+def prioritized_rand(low):
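+    # Draw one uniform value per position, then rearrange each row so that
+    # the positions flagged in the boolean tensor `low` receive the smallest
+    # values of that row and the remaining positions the largest ones.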
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = x.new(x.size())
+    y.scatter_(dim=1, index=k, src=x)
+    return y
+
+
def ae_generate(
    model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
):
        input_with_mask = NTC_channel_cat(input, mask_generate)
        logits = model(input_with_mask)
        dist = torch.distributions.categorical.Categorical(logits=logits)
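+        # Sample a complete candidate once, and build a random field whose
+        # smallest values fall on the positions where that candidate already
+        # differs from the current input, so those positions are the first
+        # to pass the erasure threshold below.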
+        final = dist.sample()
+
+        r = prioritized_rand(final != input)
-        r = torch.rand(mask_generate.size(), device=mask_generate.device)
        mask_erased = mask_generate * (r <= proba_erased).long()
        mask_to_change = d * mask_generate + (1 - d) * mask_erased
-        update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
+        update = (1 - mask_to_change) * input + mask_to_change * final
        if update.equal(input):
-            log_string(f"converged at iteration {it}")
            break
        else:
            changed = changed & (update != input).max(dim=1).values
            input[changed] = update[changed]
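+        # `changed` keeps track of the sequences still being modified;
+        # report how many of them remain at this iteration.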
+ log_string(f"remains {changed.long().sum()}")
+
    return input