return x_t
-def ___sample_x_t_given_x_0(x_0, t):
- D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
- mask = (x_0 < quiz_machine.problem.nb_colors).long()
- probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
- dist = torch.distributions.categorical.Categorical(probs=probas)
- x_t = (1 - mask) * x_0 + mask * dist.sample()
- return x_t
-
-
# This function returns a 2d tensor of same shape as low, full of
# uniform random values in [0,1], such that, in every row, the values
# corresponding to the True in low are all lesser than the values
return x_t_minus_1
+# Non-uniform transitions, to be fixed?
+
+
+def ___sample_x_t_given_x_0(x_0, t):
+ D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
+ mask = (x_0 < quiz_machine.problem.nb_colors).long()
+ probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
+ dist = torch.distributions.categorical.Categorical(probs=probas)
+ x_t = (1 - mask) * x_0 + mask * dist.sample()
+ return x_t
+
+
def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
mask = (x_0 < quiz_machine.problem.nb_colors).long()