return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
+def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
input = (1 - mask_generate) * input + mask_generate * noise
- for it in range(nb_iterations):
- rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
- input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None])
+ for it in range(nb_iterations_max):
+ input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
dist = torch.distributions.categorical.Categorical(logits=logits)
+ pred_input = input.clone()
input = (1 - mask_generate) * input + mask_generate * dist.sample()
+ if (pred_input == input).min():
+ break
return input
nb_iterations = 10
- def phi(rho):
- # return (rho / nb_iterations)**2
- return rho / nb_iterations
-
for n_epoch in range(args.nb_epochs):
# ----------------------
# Train
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
-
- targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1))
-
- input_with_mask = NTC_channel_cat(input, mask_generate, rho)
- output = model(input_with_mask)
- loss = NTC_masked_cross_entropy(output, targets, mask_loss)
+ phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25)
+ targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+ input_with_mask = NTC_channel_cat(input, mask_generate)
+ logits = model(input_with_mask)
+ loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
local_device,
"test",
):
- rho = torch.randint(
- nb_iterations, (input.size(0), 1), device=input.device
+ phi = torch.rand((input.size(0), 1), device=input.device).clamp(
+ min=0.25
)
-
- targets, input = degrade_input(
- input, mask_generate, phi(rho), phi(rho + 1)
- )
-
- input_with_mask = NTC_channel_cat(input, mask_generate, rho)
- output = model(input_with_mask)
- loss = NTC_masked_cross_entropy(output, targets, mask_loss)
+ targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+ input_with_mask = NTC_channel_cat(input, mask_generate)
+ logits = model(input_with_mask)
+ loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)