def deterministic(mask_generate):
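+ # per-sample 0/1 flag: 1 where fewer than half of the positions are
+ # flagged for generation, 0 otherwise; cast to long because PyTorch
+ # does not support `1 - d` on bool tensors, and d is mixed
+ # arithmetically with the masks below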
- return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+ return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
def ae_generate(
proba_erased = noise_proba
- d = deterministic(mask_generate)
+ d = deterministic(mask_generate)[:, None]
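+ # d has shape (B,) while the masks are (B, T); the trailing singleton
+ # dim lets d * mask_generate broadcast over the sequence dimension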
changed = True
+
for it in range(nb_iterations_max):
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
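+ # select which positions to resample: the generated ones where d == 1,
+ # the erased ones where d == 0; positions outside mask_to_change keep
+ # their current value (dist is presumably a Categorical built from
+ # logits, defined in lines elided from this hunk)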
mask_to_change = d * mask_generate + (1 - d) * mask_erased
update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
+
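+ # fixed point: no resampled position changed, so further iterations
+ # would be no-ops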
if update.equal(input):
log_string(f"converged at iteration {it}")
break
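+ # batch boundary: clear the accumulated gradients once every
+ # args.batch_size training samples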
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- d = deterministic = mask_generate
+ d = deterministic(mask_generate)
p = probs_iterations.expand(input.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=p)
N0 = dist.sample()
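+ # per-sample draw from the prior probs_iterations; N0 is presumably
+ # the number of noising iterations applied to each sequence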
local_device,
"test",
):
- d = deterministic = mask_generate
+ d = deterministic(mask_generate)
p = probs_iterations.expand(input.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=p)
N0 = dist.sample()
input,
mask_generate,
n_epoch,
- nb_iterations,
noise_proba=noise_proba,
)
+
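+ # a sample counts as correct only if every position matches: min over
+ # the boolean match tensor along dim 1 acts as a logical AND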
correct = (input == targets).min(dim=1).values.long()
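+ # broadcast the generation flags (presumably one per quadrant) to the
+ # whole batch: (Q,) -> (B, Q)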
predicted_parts = torch.tensor(quad_generate, device=input.device)
predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)