mask_loss.to(local_device),
)
-
-def degrade_input(input, mask_generate, *ts):
- noise = torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
-
- r = torch.rand(mask_generate.size(), device=mask_generate.device)
-
- result = []
-
- for t in ts:
- mask_diffusion_noise = mask_generate * (r <= t).long()
- x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
- result.append(x)
-
- return result
-
# quiz_machine.problem.save_quizzes_as_image(
# args.result_dir,
# filename="a.png",
return input
+def degrade_input(input, mask_generate, *phis):
+    # Diffusion-style corruption of `input`: for each noise level phi in
+    # `phis`, return a copy where tokens under `mask_generate` are replaced
+    # by uniformly-random color indices wherever a shared per-position
+    # uniform draw r falls at or below phi (larger phi => more corruption).
+    # Assumes mask_generate has the same shape as input (elementwise ops
+    # below) — TODO confirm against callers.
+    noise = torch.randint(
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
+
+    # A single r shared across all phis makes the corruption masks nested:
+    # every position noised at a smaller phi stays noised at any larger phi,
+    # so successive levels form a consistent degradation trajectory.
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
+
+    result = []
+
+    for phi in phis:
+        # 1 where this position is corrupted at level phi, else 0.
+        mask_diffusion_noise = mask_generate * (r <= phi).long()
+        # Keep original tokens outside the noise mask, random colors inside.
+        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+        result.append(x)
+
+    # One degraded copy per requested noise level, in the order given.
+    return result
+
+
def test_ae(local_device=main_device):
model = MyAttentionAE(
vocabulary_size=vocabulary_size,
nb_iterations = 10
+ def phi(rho):
+ # return (rho / nb_iterations)**2
+ return rho / nb_iterations
+
for n_epoch in range(args.nb_epochs):
# ----------------------
# Train
model.optimizer.zero_grad()
rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
- targets, input = degrade_input(
- input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
- )
+
+ targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1))
input_with_mask = NTC_channel_cat(input, mask_generate, rho)
output = model(input_with_mask)
rho = torch.randint(
nb_iterations, (input.size(0), 1), device=input.device
)
+
targets, input = degrade_input(
- input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
+ input, mask_generate, phi(rho), phi(rho + 1)
)
+
input_with_mask = NTC_channel_cat(input, mask_generate, rho)
output = model(input_with_mask)
loss = NTC_masked_cross_entropy(output, targets, mask_loss)