)
input = (1 - mask_generate) * input + mask_generate * noise
+ changed = True
for it in range(nb_iterations_max):
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
dist = torch.distributions.categorical.Categorical(logits=logits)
- pred_input = input.clone()
- input = (1 - mask_generate) * input + mask_generate * dist.sample()
- if (pred_input == input).min():
+ update = (1 - mask_generate) * input + mask_generate * dist.sample()
+ if update.equal(input):
break
+ else:
+ changed = changed & (update != input).max(dim=1).values
+ input[changed] = update[changed]
return input
-def degrade_input(input, mask_generate, *phis):
+def degrade_input(input, mask_generate, noise_levels):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
result = []
- for phi in phis:
+ for phi in noise_levels:
mask_diffusion_noise = mask_generate * (r <= phi).long()
x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
result.append(x)
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25)
- targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+ deterministic = (
+ mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+ ).long()
+
+ k = torch.randint(3, (input.size(0), 1), device=input.device)
+ phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+ phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+ targets, input = degrade_input(input, mask_generate, (phi0, phi1))
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
local_device,
"test",
):
- phi = torch.rand((input.size(0), 1), device=input.device).clamp(
- min=0.25
- )
- targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+ deterministic = (
+ mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+ ).long()
+
+ k = torch.randint(3, (input.size(0), 1), device=input.device)
+ phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+ phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+ phi = torch.rand((input.size(0), 1), device=input.device)
+ phi = deterministic + (1 - deterministic) * phi
+ targets, input = degrade_input(input, mask_generate, (phi0, phi1))
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)