#
-def degrade_input_to_generate(x_0, steps_nb_iterations):
+def sample_x_t_given_x_0(x_0, steps_nb_iterations):
noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
r = torch.rand(x_0.size(), device=x_0.device)
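# (lines elided here: result is assumed to contain, for each t in
# steps_nb_iterations, a copy of x_0 in which every token has been
# replaced by its value in noise with probability
# 1 - (1 - args.diffusion_noise_proba) ** t)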
return result
+# This function returns a 2D tensor of the same shape as low, full of
+# uniform random values in [0,1], such that, in every row, the values
+# at the positions where low is True are all less than the values at
+# the positions where low is False.
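+# For instance low = [[True, False, True]] could yield
+# [[0.12, 0.85, 0.31]]: both values at the True positions are smaller
+# than the value at the False position.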
+
+
+def prioritized_rand(low):
+ x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+ k = torch.rand(low.size(), device=low.device) + low.long()
+ k = k.sort(dim=1).indices
+ y = x.new(x.size())
+ y.scatter_(dim=1, index=k, src=x)
+ return y
+
+
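+# This function gets the clean x_0 and the current x_t and returns a
+# sample of x_{t-1}: in every row, roughly a fraction
+# args.diffusion_noise_proba of the tokens is set back to its value in
+# x_0, and thanks to prioritized_rand the tokens that still differ from
+# x_0 are the first ones to be reverted.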
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
+ r = prioritized_rand(x_0 != x_t)
+
+ mask_changes = (r <= args.diffusion_noise_proba).long()
+
+ x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
+
+ return x_t_minus_1
+
+
######################################################################
+# This function gets a clean target x_0 and a mask indicating which
+# part to generate (conditionally on the rest), and returns the logits
+# predicted for x_0 from an x_t ~ X_t | X_0 = x_0 sampled with a random t
+
def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
# We favor iterations near the clean signal
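# (probs_iterations is assumed to be defined in lines elided here, as a
# normalized distribution over the possible numbers of noising steps
# that puts more mass on small values of t)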
probs_iterations = probs_iterations.expand(x_0.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
- N1 = dist.sample() + 1
+ t_1 = dist.sample() + 1
- (x_t,) = degrade_input_to_generate(x_0, (N1,))
+ (x_t,) = sample_x_t_given_x_0(x_0, (t_1,))
# Only the part to generate is degraded, the rest is a perfect
# noise-free conditioning
######################################################################
-# This function returns a 2d tensor of same shape as low, full of
-# uniform random values in [0,1], such that, in every row, the values
-# corresponding to the True in low are all lesser than the values
-# corresponding to the False.
-
-
-def prioritized_rand(low):
- x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
- k = torch.rand(low.size(), device=low.device) + low.long()
- k = k.sort(dim=1).indices
- y = x.new(x.size())
- y.scatter_(dim=1, index=k, src=x)
- return y
-
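+# This function generates the part of x_0 selected by mask_generate by
+# applying the learned reverse process at most nb_iterations_max times,
+# stopping as soon as the sample no longer changes.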
def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
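# (lines elided here: the generation loop; dist is assumed to be a
# Categorical distribution over the logits predicted by model for the
# current x_t)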
hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
- r = prioritized_rand(hat_x_0 != x_t)
-
- mask_changes = (r <= args.diffusion_noise_proba).long()
-
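+ # (one_iteration_prediction is assumed to be a mask, defined in lines
+ # elided here, selecting where the model's one-shot prediction hat_x_0
+ # is used directly instead of a single reverse-diffusion step)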
- hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
+ hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
1 - one_iteration_prediction
- ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0)
+ ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
if hat_x_t_minus_1.equal(x_t):
# log_string(f"exit after {it+1} iterations")