    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
+######################################################################
+
+#
+# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, ..., with
+#
+# x_{t_k} ~ P(X_{t_k} | X_0=x_0)
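+#
+# Concretely, each token selected by the mask is resampled uniformly
+# at random in x_{t_k} with probability 1 - (1 - p)^{t_k}, where
+# p = args.diffusion_noise_proba is the per-iteration erasure proba.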
+#
+
+
+def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
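+
+    # The same r is shared across all the t_k, so the sets of erased
+    # tokens are nested: a token erased for some t_k stays erased for
+    # every larger t_k, and the x_{t_k} form a single degradation
+    # trajectory.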
+
+    result = []
+
+    for n in steps_nb_iterations:
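+        # proba that a masked token has been erased at least once
+        # after n noising steps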
+        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
+        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * x0 + mask_erased * noise
+        result.append(x)
+
+    return result
+
+
+######################################################################
+
+# Given x_0 and a mask of the tokens to generate, samples a number of
+# noising iterations, degrades the masked tokens accordingly, and
+# returns the targets together with the logits computed by the model
+# from the degraded input.
+
+
+def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
+    d = deterministic(mask_generate)
+
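+    # Distribution over the number of noising iterations, decaying
+    # geometrically: sampling a single iteration is ten times more
+    # likely than sampling args.nb_diffusion_iterations of them.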
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.nb_diffusion_iterations, device=input.device
+    )
+
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+    # N0 = dist.sample()
+    # N1 = N0 + 1
+    # N0 = (1 - d) * N0
+    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
+
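+    # The targets are the clean x_0 (N0 = 0) and the model's input is
+    # x_{N1}, degraded with N1 >= 1 iterations sampled from dist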
+    N0 = input.new_zeros(input.size(0))
+    N1 = dist.sample() + 1
+
+    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+
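+    # Optionally corrupts a fraction prompt_noise of the prompt, that
+    # is, only the tokens outside mask_generate; the tokens to
+    # generate keep their diffusion degradation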
+    if prompt_noise > 0:
+        mask_prompt_noise = (
+            torch.rand(input.size(), device=input.device) <= prompt_noise
+        ).long()
+        noise = torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
+        input = (1 - mask_generate) * noisy_input + mask_generate * input
+
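+    # Presumably NTC_channel_cat appends mask_generate to the input
+    # as an extra channel, so that the model knows which tokens it
+    # has to denoise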
+    input_with_mask = NTC_channel_cat(input, mask_generate)
+    logits = model(input_with_mask)
+
+    return targets, logits
+
+
+######################################################################
+
# This function returns a 2d tensor of the same shape as low, full of
# uniform random values in [0,1], such that, in every row, the values
# corresponding to the True in low are all less than the values
        mask_generate = quiz_machine.make_quiz_mask(
            quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
        )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
            model, q, mask_generate, prompt_noise=args.prompt_noise
        )
        loss_per_token = F.cross_entropy(
        mask_generate = quiz_machine.make_quiz_mask(
            quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
        )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
            model, q, mask_generate, prompt_noise=args.prompt_noise
        )
        mask_generate = quiz_machine.make_quiz_mask(
            quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
        )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
            model, q, mask_generate, prompt_noise=args.prompt_noise
        )
######################################################################
-def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
-
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
-
-    result = []
-
-    for n in steps_nb_iterations:
-        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * input + mask_erased * noise
-        result.append(x)
-
-    return result
-
-
-def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
-    d = deterministic(mask_generate)
-
-    probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=input.device
-    )
-
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(input.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-    # N0 = dist.sample()
-    # N1 = N0 + 1
-    # N0 = (1 - d) * N0
-    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
-    N0 = input.new_zeros(input.size(0))
-    N1 = dist.sample() + 1
-
-    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
-
-    if prompt_noise > 0:
-        mask_prompt_noise = (
-            torch.rand(input.size(), device=input.device) <= prompt_noise
-        ).long()
-        noise = torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
-        )
-        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
-        input = (1 - mask_generate) * noisy_input + mask_generate * input
-
-    input_with_mask = NTC_channel_cat(input, mask_generate)
-    logits = model(input_with_mask)
-
-    return targets, logits
-
-
-######################################################################
-
-
def run_ae_test(
    model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
):
        c_quizzes=c_quizzes,
        desc="test",
    ):
-        targets, logits = targets_and_prediction(model, input, mask_generate)
+        targets, logits = targets_and_logits(model, input, mask_generate)
        loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
        acc_test_loss += loss.item() * input.size(0)
        nb_test_samples += input.size(0)
        f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
    )
-    if prefix is None:
-        model.test_accuracy = nb_correct / nb_total
+    model.test_accuracy = nb_correct / nb_total
    # Save some images
        if nb_train_samples % args.batch_size == 0:
            model.optimizer.zero_grad()
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
            model, input, mask_generate, prompt_noise=args.prompt_noise
        )