return input
-def degrade_input(input, mask_generate, noise_levels):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
result = []
- for phi in noise_levels:
- mask_diffusion_noise = mask_generate * (r <= phi).long()
- x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+ for n in nb_iterations:
+ proba_erased = 1 - (1 - noise_proba) ** n
+ mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+ x = (1 - mask_erased) * input + mask_erased * noise
result.append(x)
return result
model.optimizer.zero_grad()
deterministic = (
- mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+ mask_generate.sum(dim=1) < mask_generate.size(1) // 4
).long()
- k = torch.randint(3, (input.size(0), 1), device=input.device)
- phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
- phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+ N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+ N1 = N0 + 1
- targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+ N0 = (1 - deterministic) * N0
+ N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+
+ # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+
+ targets, input = degrade_input(input, mask_generate, (N0, N1))
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
"test",
):
deterministic = (
- mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+ mask_generate.sum(dim=1) < mask_generate.size(1) // 4
).long()
- k = torch.randint(3, (input.size(0), 1), device=input.device)
- phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
- phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+ N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+ N1 = N0 + 1
+
+ N0 = (1 - deterministic) * N0
+ N1 = deterministic * nb_iterations + (1 - deterministic) * N1
- phi = torch.rand((input.size(0), 1), device=input.device)
- phi = deterministic + (1 - deterministic) * phi
- targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+ targets, input = degrade_input(input, mask_generate, (N0, N1))
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)