return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
+def deterministic(mask_generate):
+ return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+
+
+def ae_generate(
+ model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
+):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
input = (1 - mask_generate) * input + mask_generate * noise
+ proba_erased = noise_proba
+
+ d = deterministic(mask_generate)
changed = True
for it in range(nb_iterations_max):
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
dist = torch.distributions.categorical.Categorical(logits=logits)
- update = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+ r = torch.rand(mask_generate.size(), device=mask_generate.device)
+ mask_erased = mask_generate * (r <= proba_erased).long()
+ mask_to_change = d * mask_generate + (1 - d) * mask_erased
+
+ update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
if update.equal(input):
break
else:
return input
-def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
model.to(local_device).train()
optimizer_to(model.optimizer, local_device)
- nb_iterations = 10
+ nb_iterations = 25
+ probs_iterations = torch.arange(nb_iterations, device=main_device)
+ probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
for n_epoch in range(args.nb_epochs):
# ----------------------
model.train()
nb_train_samples, acc_train_loss = 0, 0.0
+ noise_proba = 0.05
+
for input, mask_generate, mask_loss in ae_batches(
quiz_machine,
args.nb_train_samples,
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- deterministic = (
- mask_generate.sum(dim=1) < mask_generate.size(1) // 4
- ).long()
-
- N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+ d = deterministic = mask_generate
+ p = probs_iterations.expand(input.size(0), -1)
+ dist = torch.distributions.categorical.Categorical(probs=p)
+ N0 = dist.sample()
N1 = N0 + 1
+ N0 = (1 - d) * N0
+ N1 = (1 - d) * N1 + d * nb_iterations
- N0 = (1 - deterministic) * N0
- N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+ targets, input = degrade_input(
+ input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+ )
- # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # for n in ["input", "targets"]:
+ # filename = f"{n}.png"
+ # quiz_machine.problem.save_quizzes_as_image(
+ # args.result_dir,
+ # filename,
+ # quizzes=locals()[n],
+ # )
+    #         log_string(f"wrote {filename}")
+ # time.sleep(1000)
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- targets, input = degrade_input(input, mask_generate, (N0, N1))
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
local_device,
"test",
):
- deterministic = (
- mask_generate.sum(dim=1) < mask_generate.size(1) // 4
- ).long()
-
- N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+ d = deterministic = mask_generate
+ p = probs_iterations.expand(input.size(0), -1)
+ dist = torch.distributions.categorical.Categorical(probs=p)
+ N0 = dist.sample()
N1 = N0 + 1
-
- N0 = (1 - deterministic) * N0
- N1 = deterministic * nb_iterations + (1 - deterministic) * N1
-
- targets, input = degrade_input(input, mask_generate, (N0, N1))
+ N0 = (1 - d) * N0
+ N1 = (1 - d) * N1 + d * nb_iterations
+ targets, input = degrade_input(
+ input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+ )
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
)
targets = input.clone()
- input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+ input = ae_generate(
+ model,
+ input,
+ mask_generate,
+ n_epoch,
+ nb_iterations,
+ noise_proba=noise_proba,
+ )
correct = (input == targets).min(dim=1).values.long()
predicted_parts = torch.tensor(quad_generate, device=input.device)
predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)