parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
-parser.add_argument("--diffusion_delta", type=float, default=0.1)
+parser.add_argument("--diffusion_delta", type=float, default=0.05)
-parser.add_argument("--diffusion_epsilon", type=float, default=0.01)
+parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
parser.add_argument("--min_succeed_to_validate", type=int, default=2)
diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
-diffusion_M[1, 0, 1:] = args.diffusion_delta
-diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2)
+diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta
+diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1)
+
for k in range(1, N):
diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
#
-def sample_x_t_given_x_0_(x_0, t):
+def sample_x_t_given_x_0(x_0, t):
noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
r = torch.rand(x_0.size(), device=x_0.device)
proba_erased = 1 - (1 - args.diffusion_delta) ** t
return x_t
-def sample_x_t_given_x_0(x_0, t):
+def ___sample_x_t_given_x_0(x_0, t):
D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
mask = (x_0 < quiz_machine.problem.nb_colors).long()
probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
return y
-def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t):
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
r = prioritized_rand(x_0 != x_t)
mask_changes = (r <= args.diffusion_delta).long()
return x_t_minus_1
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
+def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
mask = (x_0 < quiz_machine.problem.nb_colors).long()
# i = x_0[n,s], j = x_t[n,s]
def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
- # noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
+ noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- x_t = (1 - mask_generate) * x_0 # + mask_generate * noise
+ x_t = (1 - mask_generate) * x_0 + mask_generate * noise
one_iteration_prediction = deterministic(mask_generate)[:, None]